Skip to content

Commit

Permalink
Upgrade Linters (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingychiu authored Feb 6, 2024
1 parent 5d8a012 commit 30f04fe
Show file tree
Hide file tree
Showing 6 changed files with 1,385 additions and 1,373 deletions.
2,706 changes: 1,362 additions & 1,344 deletions poetry.lock

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ beartype = "^0.14.1"
scipy = "^1.9"

[tool.poetry.group.dev.dependencies]
ruff = "^0.0.275"
black = "^23.3.0"
isort = "^5.12.0"
mypy = "^1.4.1"
ruff = "^0.2.1"
black = "^24.1.1"
isort = "^5.13.2"
mypy = "^1.8.0"
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
datasets = "^2.13.1"
Expand All @@ -91,6 +91,12 @@ build-backend = "poetry.core.masonry.api"

[tool.ruff]
exclude = ["__init__.py"]
select = ["E", "F", "PL", "I", "ICN", "RET", "SIM", "NPY", "RUF"]
ignore = ["E501", "PLR0913"]
lint.select = ["E", "F", "PL", "I", "ICN", "RET", "SIM", "NPY", "RUF"]
lint.ignore = ["E501", "PLR0913"]

[tool.mypy]
exclude = [ ]
ignore_missing_imports = true
check_untyped_defs = true
python_version = "3.9"

6 changes: 3 additions & 3 deletions target_permutation_importances/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def compute_permutation_importance_by_wasserstein_distance(
"importance"
].to_numpy(),
)
mean_actual_importance_df[
"wasserstein_distance"
] = mean_actual_importance_df.index.map(distances)
mean_actual_importance_df["wasserstein_distance"] = (
mean_actual_importance_df.index.map(distances)
)

# Sort by feature name to make sure the order is the same
mean_actual_importance_df = mean_actual_importance_df.sort_index()
Expand Down
1 change: 1 addition & 0 deletions target_permutation_importances/sklearn_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
The Sklearn Class Wrappers
"""

import numpy as np
import pandas as pd
from beartype.typing import Any, Dict, List, Union
Expand Down
21 changes: 7 additions & 14 deletions target_permutation_importances/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ class XBuilderType(Protocol): # pragma: no cover
return (XType): The X data
"""

def __call__(self, is_random_run: bool, run_idx: int) -> XType:
...
def __call__(self, is_random_run: bool, run_idx: int) -> XType: ...


@runtime_checkable
Expand All @@ -37,14 +36,12 @@ class YBuilderType(Protocol): # pragma: no cover
return (YType): The y data
"""

def __call__(self, is_random_run: bool, run_idx: int) -> YType:
...
def __call__(self, is_random_run: bool, run_idx: int) -> YType: ...


@runtime_checkable
class ModelFitParamsBuilderType(Protocol): # pragma: no cover
def __call__(self, feature_columns: Optional[List[str]]) -> dict:
...
def __call__(self, feature_columns: Optional[List[str]]) -> dict: ...


@runtime_checkable
Expand All @@ -60,8 +57,7 @@ class ModelBuilderType(Protocol): # pragma: no cover
return (Any): The newly created model
"""

def __call__(self, is_random_run: bool, run_idx: int) -> Any:
...
def __call__(self, is_random_run: bool, run_idx: int) -> Any: ...


@runtime_checkable
Expand All @@ -77,8 +73,7 @@ class ModelFitterType(Protocol): # pragma: no cover
return (Any): The fitted model
"""

def __call__(self, model: Any, X: XType, y: YType) -> Any:
...
def __call__(self, model: Any, X: XType, y: YType) -> Any: ...


@runtime_checkable
Expand All @@ -94,8 +89,7 @@ class ModelImportanceGetter(Protocol): # pragma: no cover
return (pd.DataFrame): The return DataFrame with columns ["feature", "importance"]
"""

def __call__(self, model: Any, X: XType, y: YType) -> pd.DataFrame:
...
def __call__(self, model: Any, X: XType, y: YType) -> pd.DataFrame: ...


@runtime_checkable
Expand All @@ -116,5 +110,4 @@ def __call__(
self,
actual_importance_dfs: List[pd.DataFrame],
random_importance_dfs: List[pd.DataFrame],
) -> pd.DataFrame:
...
) -> pd.DataFrame: ...
6 changes: 0 additions & 6 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ def test_compute_multi_label_classification(model_cls, imp_func, xtype):
)
if xtype is pd.DataFrame:
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
else:
X = X

model_cls_params = model_cls[1].copy()
if "Cat" in model_cls[0].__name__:
Expand Down Expand Up @@ -174,8 +172,6 @@ def test_compute_multi_label_classification_with_MultiOutputClassifier(
)
if xtype is pd.DataFrame:
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
else:
X = X

if "Lasso" in model_cls[0].__name__ or "LinearSVC" in model_cls[0].__name__:
return
Expand Down Expand Up @@ -247,8 +243,6 @@ def test_compute_multi_target_regression_with_MultiOutputRegressor(
)
if xtype is pd.DataFrame:
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
else:
X = X

result_df = compute(
model_cls=MultiOutputRegressor,
Expand Down

0 comments on commit 30f04fe

Please sign in to comment.