Skip to content

Commit

Permalink
Fix incompatibility errors raises()
Browse files Browse the repository at this point in the history
Different in hydra >= 1.1.2
  • Loading branch information
dunnkers committed Jul 9, 2022
1 parent 4209705 commit a450bba
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
16 changes: 8 additions & 8 deletions tests/integration/pipelines/test_rank_and_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import numpy as np
import pandas as pd
import pytest
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict
from sklearn.base import BaseEstimator

from fseval.config import (
CrossValidatorConfig,
DatasetConfig,
Expand All @@ -16,8 +11,13 @@
ResampleConfig,
)
from fseval.pipeline.dataset import Dataset, DatasetLoader
from fseval.types import AbstractAdapter, IncompatibilityError, Task
from fseval.types import AbstractAdapter, Task
from fseval.utils.hydra_utils import get_config
from hydra.core.config_store import ConfigStore
from hydra.errors import InstantiationException
from hydra.utils import instantiate
from omegaconf import DictConfig, open_dict
from sklearn.base import BaseEstimator

cs = ConfigStore.instance()

Expand Down Expand Up @@ -206,7 +206,7 @@ def test_with_ranker_gt_no_importances_substitution(cfg: PipelineConfig):


def test_validator_incompatibility_check(cfg: PipelineConfig):
with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
cfg.dataset.n = 5
cfg.dataset.p = 5
cfg.dataset.multioutput = False
Expand All @@ -215,7 +215,7 @@ def test_validator_incompatibility_check(cfg: PipelineConfig):


def test_ranker_incompatibility_check(cfg: PipelineConfig):
with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
cfg.dataset.n = 5
cfg.dataset.p = 5
cfg.dataset.multioutput = False
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fseval.types import IncompatibilityError
from fseval.utils.hydra_utils import get_config
from hydra.conf import ConfigStore
from hydra.errors import InstantiationException


@pytest.fixture
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_pipeline_incompatibility(incompatible_cfg: PipelineConfig):
"""Pipeline should throw IncompatibilityError when trying to run a classification
method on a regression dataset."""

with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
run_pipeline(incompatible_cfg, raise_incompatibility_errors=True)


Expand Down
9 changes: 5 additions & 4 deletions tests/unit/pipeline/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import pytest
from hydra.utils import instantiate
from hydra.errors import InstantiationException
from omegaconf import OmegaConf
from sklearn.base import BaseEstimator

from fseval.config import EstimatorConfig
from fseval.pipeline.estimator import Estimator
from fseval.storage.local import LocalStorage
from fseval.types import CacheUsage, IncompatibilityError, Task
from fseval.types import CacheUsage, Task


@pytest.fixture
Expand Down Expand Up @@ -79,23 +80,23 @@ def test_incompatibility(estimator_cfg: EstimatorConfig):
# classification estimator, but regression task
estimator_cfg._estimator_type = "classifier"
estimator_cfg.task = Task.regression
with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
instantiate(estimator_cfg)

# multioutput, but estimator does not support it (`multioutput=False`)
estimator_cfg._estimator_type = "classifier"
estimator_cfg.task = Task.classification
estimator_cfg.multioutput = False
estimator_cfg.is_multioutput_dataset = True
with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
instantiate(estimator_cfg)

# multioutput only, but
estimator_cfg._estimator_type = "classifier"
estimator_cfg.task = Task.classification
estimator_cfg.multioutput_only = True
estimator_cfg.is_multioutput_dataset = False
with pytest.raises(IncompatibilityError):
with pytest.raises(InstantiationException):
instantiate(estimator_cfg)


Expand Down

0 comments on commit a450bba

Please sign in to comment.