Skip to content

Commit

Permalink
Rename estimator.
Browse files Browse the repository at this point in the history
  • Loading branch information
isaksamsten committed Jan 22, 2024
1 parent ed77b17 commit 70c5dee
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 47 deletions.
15 changes: 3 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
language_version: python3

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.259"
rev: v0.1.14
hooks:
- id: ruff

- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
- id: nbstripout
args: [ --fix ]
- id: ruff-format
4 changes: 4 additions & 0 deletions docs/more/whatsnew.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ Wildboar 1.2 introduces several new models.
as described by Guillaume et al., (2022).
- :class:`linear_model.DilatedShapeletClassifier`: a new shapelet based classifier
as described by Guillaume et al., (2022).
- :class:`transform.CastorClassifier`: a new shapelet based transform using
competing shapelets.
- :class:`linear_model.CastorClassifier`: a new shapelet based classifier using
competing shapelets.

Changelog
---------
Expand Down
4 changes: 2 additions & 2 deletions src/wildboar/linear_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ._hydra import HydraClassifier
from ._rocket import RocketClassifier, RocketRegressor
from ._shapelet import (
CompetingDilatedShapeletClassifier,
CastorClassifier,
DilatedShapeletClassifier,
RandomShapeletClassifier,
RandomShapeletRegressor,
Expand All @@ -19,5 +19,5 @@
"RandomShapeletClassifier",
"RandomShapeletRegressor",
"DilatedShapeletClassifier",
"CompetingDilatedShapeletClassifier",
"CastorClassifier",
]
6 changes: 3 additions & 3 deletions src/wildboar/linear_model/_rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class RocketClassifier(TransformRidgeClassifierCV):
"""Implements the ROCKET classifier."""

def __init__(
def __init__( # noqa: PLR0913
self,
n_kernels=10000,
*,
Expand All @@ -25,7 +25,7 @@ def __init__(
class_weight=None,
normalize=True,
n_jobs=None,
random_state=None
random_state=None,
):
super().__init__(
alphas=alphas,
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
cv=None,
gcv_mode=None,
n_jobs=None,
random_state=None
random_state=None,
):
super().__init__(
alphas=alphas,
Expand Down
14 changes: 7 additions & 7 deletions src/wildboar/linear_model/_shapelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..datasets.preprocess import SparseScaler
from ..transform import (
CompetingDilatedShapeletTransform,
CastorTransform,
DiffTransform,
DilatedShapeletTransform,
RandomShapeletTransform,
Expand Down Expand Up @@ -235,7 +235,7 @@ def _get_transform(self, random_state):
)


class CompetingDilatedShapeletClassifier(TransformRidgeClassifierCV):
class CastorClassifier(TransformRidgeClassifierCV):
"""
A dictionary based method using dilated competing shapelets.
Expand Down Expand Up @@ -355,7 +355,7 @@ def _build_pipeline(self):

return pipeline

def _get_transform(self, random_state):
def _get_transform(self, random_state): # noqa: PLR0912
random_state = check_random_state(random_state)
params = dict(
n_shapelets=self.n_shapelets,
Expand Down Expand Up @@ -396,7 +396,7 @@ def _get_transform(self, random_state):
ng += 1

union.append(
CompetingDilatedShapeletTransform(
CastorTransform(
n_groups=ng,
shapelet_size=size,
random_state=random_state.randint(np.iinfo(np.int32).max),
Expand All @@ -406,7 +406,7 @@ def _get_transform(self, random_state):
union.append(
make_pipeline(
DiffTransform(order=self.order),
CompetingDilatedShapeletTransform(
CastorTransform(
n_groups=ng,
shapelet_size=size,
random_state=random_state.randint(np.iinfo(np.int32).max),
Expand All @@ -422,7 +422,7 @@ def _get_transform(self, random_state):
shapelet_size = self.shapelet_size

if len(shapelet_size) == 1:
return CompetingDilatedShapeletTransform(
return CastorTransform(
n_groups=self.n_groups,
shapelet_size=shapelet_size[0],
random_state=random_state,
Expand All @@ -444,7 +444,7 @@ def _get_transform(self, random_state):
ng += 1

union.append(
CompetingDilatedShapeletTransform(
CastorTransform(
n_groups=ng,
shapelet_size=size,
random_state=random_state.randint(np.iinfo(np.int32).max),
Expand Down
4 changes: 2 additions & 2 deletions src/wildboar/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
symbolic_aggregate_approximation,
)
from ._shapelet import (
CompetingDilatedShapeletTransform,
CastorTransform,
DilatedShapeletTransform,
RandomShapeletTransform,
)
Expand All @@ -38,5 +38,5 @@
"HydraTransform",
"DiffTransform",
"DerivativeTransform",
"CompetingDilatedShapeletTransform",
"CastorTransform",
]
11 changes: 8 additions & 3 deletions src/wildboar/transform/_cshapelet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ cdef class CastorSummarizer:
self.max_values = NULL
self.min_so_values = NULL

def __init__(self, soft_min=False, soft_max=True, soft_threshold=False):
def __init__(self, soft_min=True, soft_max=False, soft_threshold=True):
self.soft_min = soft_min
self.soft_max = soft_max
self.soft_threshold = soft_threshold
Expand All @@ -849,6 +849,9 @@ cdef class CastorSummarizer:
def __dealloc__(self):
self._free()

cdef Py_ssize_t get_n_features(self) noexcept nogil:
return 3

cdef void _free(self) noexcept nogil:
if self.min_values != NULL:
free(self.min_values)
Expand Down Expand Up @@ -943,7 +946,7 @@ cdef class CastorSummarizer:
max_value[0] = value
max_index[0] = i

cdef class CompetingDilatedShapeletAttributeGenerator(AttributeGenerator):
cdef class CastorAttributeGenerator(AttributeGenerator):
cdef Py_ssize_t n_shapelets
cdef Py_ssize_t shapelet_size
cdef Py_ssize_t n_groups
Expand Down Expand Up @@ -1081,7 +1084,9 @@ cdef class CompetingDilatedShapeletAttributeGenerator(AttributeGenerator):
# n_shapelets * max_exponent * n_shapelets * 3 features per time series.
cdef Py_ssize_t get_n_outputs(self, TSArray X) noexcept nogil:
cdef Py_ssize_t max_exponent = _max_exponent(X.shape[2], self.shapelet_size)
return self.get_n_attributess(X) * max_exponent * self.n_shapelets * 3
return (
self.get_n_attributess(X) * max_exponent * self.n_shapelets * self.summarizer.get_n_features()
)

cdef Py_ssize_t _get_distance_profile(
self,
Expand Down
16 changes: 7 additions & 9 deletions src/wildboar/transform/_shapelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from ..distance._multi_metric import make_subsequence_metrics
from ._base import BaseAttributeTransform
from ._cshapelet import (
CastorAttributeGenerator,
CastorSummarizer,
CompetingDilatedShapeletAttributeGenerator,
DilatedShapeletAttributeGenerator,
RandomMultiMetricShapeletAttributeGenerator,
RandomShapeletAttributeGenerator,
Expand Down Expand Up @@ -41,7 +41,7 @@ class ShapeletMixin:
],
}

def _get_generator(self, x, y):
def _get_generator(self, x, y): # noqa: PLR0912
if self.min_shapelet_size > self.max_shapelet_size:
raise ValueError(
f"The min_shapelet_size parameter of {type(self).__qualname__} "
Expand Down Expand Up @@ -184,7 +184,7 @@ def _get_generator(self, x, y):
)


class CompetingDilatedShapeletMixin:
class CastorMixin:
_parameter_constraints = {
"n_groups": [Interval(numbers.Integral, 1, None, closed="left")],
"n_shapelets": [Interval(numbers.Integral, 1, None, closed="left")],
Expand Down Expand Up @@ -213,7 +213,7 @@ def _get_generator(self, x, y):
samples_per_label = None
y = None

return CompetingDilatedShapeletAttributeGenerator(
return CastorAttributeGenerator(
self.n_groups,
self.n_shapelets,
_odd_shapelet_size(self.shapelet_size, x.shape[-1]),
Expand All @@ -228,11 +228,9 @@ def _get_generator(self, x, y):
)


class CompetingDilatedShapeletTransform(
CompetingDilatedShapeletMixin, BaseAttributeTransform
):
class CastorTransform(CastorMixin, BaseAttributeTransform):
"""
A dictionary based method using dilated competing shapelets.
Competing Dialated Shapelet Transform.
Parameters
----------
Expand Down Expand Up @@ -276,7 +274,7 @@ class CompetingDilatedShapeletTransform(
"""

_parameter_constraints = {
**CompetingDilatedShapeletMixin._parameter_constraints,
**CastorMixin._parameter_constraints,
**BaseAttributeTransform._parameter_constraints,
}

Expand Down
10 changes: 3 additions & 7 deletions tests/wildboar/transform/test_dilated_shapelet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from wildboar.datasets import load_gun_point
from wildboar.transform import (
CompetingDilatedShapeletTransform,
CastorTransform,
DilatedShapeletTransform,
)

Expand Down Expand Up @@ -43,9 +43,7 @@ def test_dilated_shapelet_supervised_transform():

def test_competing_dilated_shapelet_unsupervised_transform():
X, y = load_gun_point()
f = CompetingDilatedShapeletTransform(
random_state=1, n_groups=1, normalize_prob=0.5
)
f = CastorTransform(random_state=1, n_groups=1, normalize_prob=0.5)
f.fit(X)

# fmt: off
Expand Down Expand Up @@ -77,9 +75,7 @@ def test_competing_dilated_shapelet_unsupervised_transform():

def test_competing_dilated_shapelet_supervised_transform():
X, y = load_gun_point()
f = CompetingDilatedShapeletTransform(
random_state=1, n_groups=1, normalize_prob=0.5
)
f = CastorTransform(random_state=1, n_groups=1, normalize_prob=0.5)
f.fit(X, y)

# fmt: off
Expand Down
4 changes: 2 additions & 2 deletions tests/wildboar/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from wildboar.transform import (
PAA,
SAX,
CompetingDilatedShapeletTransform,
CastorTransform,
DilatedShapeletTransform,
HydraTransform,
IntervalTransform,
Expand Down Expand Up @@ -33,7 +33,7 @@
(PAA(), []),
(HydraTransform(), []),
(DilatedShapeletTransform(), []),
(CompetingDilatedShapeletTransform(), []),
(CastorTransform(), []),
],
)
def test_estimator_checks(estimator, skip):
Expand Down

0 comments on commit 70c5dee

Please sign in to comment.