diff --git a/Mikado/configuration/configuration.py b/Mikado/configuration/configuration.py index b4c01062e..1f028f25a 100644 --- a/Mikado/configuration/configuration.py +++ b/Mikado/configuration/configuration.py @@ -7,7 +7,7 @@ from .prepare_config import PrepareConfiguration from .serialise_config import SerialiseConfiguration from ..utilities.dbutils import DBConfiguration -from ..utilities.log_utils import LoggingConfiguration, create_null_logger +from ..utilities.log_utils import LoggingConfiguration, create_null_logger, create_default_logger from Mikado._transcripts.scoring_configuration import ScoringFile import os import yaml @@ -94,11 +94,11 @@ def load_scoring(self, logger=None): """ if logger is None: - logger = create_null_logger("check_scoring") + logger = create_default_logger("check_scoring", level="WARNING") if self.pick.scoring_file is None: if self._loaded_scoring != self.pick.scoring_file: logger.warning(f"Resetting the scoring to its previous value ({self._loaded_scoring})") - self.pick.scoring_file = self._loaded_scoring + self.pick.scoring_file = self._loaded_scoring = os.path.abspath(self._loaded_scoring) elif self._loaded_scoring != self.pick.scoring_file: logger.debug("Overwriting the scoring self using '%s' as scoring file", self.pick.scoring_file) self.scoring_file = None @@ -147,7 +147,8 @@ def load_scoring(self, logger=None): self.pick.scoring_file = option if found is True: logger.info("Found the correct option: %s", option) - self.pick.scoring_file = option + self.pick.scoring_file = os.path.abspath(option) + self._loaded_scoring = os.path.abspath(option) break if not found: raise InvalidConfiguration("No scoring configuration file found. Options: {}".format(",".join(options))) diff --git a/Mikado/tests/test_configurators.py b/Mikado/tests/test_configurators.py index 2ffa71e72..fd3fbda43 100644 --- a/Mikado/tests/test_configurators.py +++ b/Mikado/tests/test_configurators.py @@ -1,7 +1,13 @@ +import shutil import unittest +from dataclasses import asdict import marshmallow +import tempfile +import pkg_resources + +from .. import create_default_logger from ..configuration import configurator from .._transcripts.scoring_configuration import SizeFilter, TargetScore from ..configuration.configuration import * @@ -87,3 +93,34 @@ def test_load_plant_scoring(self): self.assertTrue(mammalian.scoring.requirements.compiled == plant.scoring.requirements.compiled != insect.scoring.requirements.compiled) self.assertTrue(mammalian.scoring.requirements.parameters != plant.scoring.requirements.parameters) + + def test_load_invalid_scoring(self): + erroneous = tempfile.NamedTemporaryFile(suffix=".yaml", mode="wt") + plant = MikadoConfiguration(pick=PickConfiguration(scoring_file="plant.yaml")) + err_scoring = copy.deepcopy(plant.scoring) + key = next(iter(err_scoring.scoring.keys())) + err_scoring.scoring[key].rescaling = "invalid" + key = next(iter(err_scoring.requirements.parameters.keys())) + del err_scoring.requirements.parameters[key] + yaml.dump(asdict(err_scoring), erroneous) + erroneous.flush() + plant.pick.scoring_file = erroneous.name + logger = create_default_logger("test_load_invalid_scoring", level="DEBUG") + with self.assertRaises(InvalidConfiguration): + plant.load_scoring(logger=logger) + current = os.getcwd() + os.chdir(os.path.dirname(erroneous.name)) + plant.scoring_file = os.path.basename(erroneous.name) + self.assertFalse(os.path.exists("plant.yaml")) + os.link(erroneous.name, "plant.yaml") + plant._loaded_scoring = None + plant.pick.scoring_file = "plant.yaml" + with self.assertRaises(InvalidConfiguration): + plant.load_scoring(logger=logger) + os.remove(os.path.join(tempfile.gettempdir(), "plant.yaml")) + plant._loaded_scoring = None + plant.load_scoring(logger=logger) + self.assertEqual(plant._loaded_scoring, + pkg_resources.resource_filename("Mikado.configuration", os.path.join("scoring_files", + "plant.yaml"))) + os.chdir(current)