diff --git a/.github/workflows/preprocessing-tests.yaml b/.github/workflows/preprocessing-tests.yaml index ba17cdcd3..3b52ee54a 100644 --- a/.github/workflows/preprocessing-tests.yaml +++ b/.github/workflows/preprocessing-tests.yaml @@ -28,6 +28,6 @@ jobs: - name: Run tests run: | pip install -e '.[test]' - python3 tests/test.py + pytest shell: micromamba-shell {0} working-directory: preprocessing/nextclade/ diff --git a/preprocessing/nextclade/.justfile b/preprocessing/nextclade/.justfile index 9cac9bf3a..3cb2f0364 100644 --- a/preprocessing/nextclade/.justfile +++ b/preprocessing/nextclade/.justfile @@ -1,5 +1,14 @@ all: ruff_format ruff_check run_mypy +create_env: + micromamba create -f environment.yml --rc-file .mambarc + +install: + pip install -e . + +install_test: + pip install -e '.[test]' + r: ruff ruff: ruff_check ruff_format diff --git a/preprocessing/nextclade/README.md b/preprocessing/nextclade/README.md index e5b490be6..e8082ce8b 100644 --- a/preprocessing/nextclade/README.md +++ b/preprocessing/nextclade/README.md @@ -15,18 +15,20 @@ This preprocessing pipeline is still a work in progress. It requests unaligned n ## Setup -### Start directly +### Installation 1. Install `conda`/`mamba`/`micromamba`: see e.g. [micromamba installation docs](https://mamba.readthedocs.io/en/latest/micromamba-installation.html#umamba-install) -2. Install environment: +1. Install environment: - ```bash + ```sh mamba env create -n loculus-nextclade -f environment.yml ``` -3. Start backend (see [backend README](../backend/README.md)), run ingest script to submit sequences from INSDC. (Alternatively you can run `./deploy.py --enablePreprocessing` to start the backend and preprocessing pods in one command.) +### Running + +1. Start backend (see [backend README](../backend/README.md)), run ingest script to submit sequences from INSDC. (Alternatively you can run `./deploy.py --enablePreprocessing` to start the backend and preprocessing pods in one command.) -4. Run pipeline +1. Run pipeline ```bash mamba activate loculus-nextclade @@ -38,10 +40,10 @@ This preprocessing pipeline is still a work in progress. It requests unaligned n Tests can be run from the same directory: -```bash +```sh mamba activate loculus-nextclade -pip install -e . -python3 tests/test.py +pip install -e '.[test]' +pytest ``` Note that we do not add the tests folder to the docker image. In CI, tests are run using the same mamba environment as the preprocessing docker image, but they do not use the actual docker image. We chose this approach as it makes the CI tests faster, but it could potentially lead to the tests using a different program version than the one used in the docker image. @@ -66,13 +68,13 @@ docker run -it --platform=linux/amd64 --network host --rm nextclade_processing p When deployed on kubernetes the preprocessing pipeline reads in config files which are created by `loculus/kubernetes/loculus/templates/loculus-preprocessing-config.yaml`. When run locally the pipeline uses only the default values defined in `preprocessing/nextclade/src/loculus_preprocessing/config.py`.
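Such a config file is plain YAML whose keys mirror the fields of the `Config` dataclass in `config.py`. A minimal sketch (the field values here are illustrative, taken from the dataclass defaults; this is not a complete production config):

```yaml
organism: mpox
backend_host: http://127.0.0.1:8079/mpox
keycloak_host: http://127.0.0.1:8083
log_level: DEBUG
batch_size: 5
pipeline_version: 1
nucleotideSequences:
  - main
```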
When running the preprocessing pipeline locally it makes sense to create a local config file using the command: -``` +```sh ../../generate_local_test_config.sh ``` and use this in the pipeline as follows: -``` +```sh prepro --config-file=../../temp/preprocessing-config.{organism}.yaml --keep-tmp-dir ``` @@ -103,7 +105,7 @@ However, the `preprocessing` field can be customized to take an arbitrary number Using these functions in your `values.yaml` will look like: -``` +```yaml - name: sampleCollectionDate type: date preprocessing: diff --git a/preprocessing/nextclade/pyproject.toml b/preprocessing/nextclade/pyproject.toml index ae8e1b3a2..0190efef2 100644 --- a/preprocessing/nextclade/pyproject.toml +++ b/preprocessing/nextclade/pyproject.toml @@ -15,4 +15,4 @@ build-backend = "hatchling.build" packages = ["src/loculus_preprocessing"] [project.optional-dependencies] -test = ["pytest"] \ No newline at end of file +test = ["pytest", "mypy", "types-pytz"] \ No newline at end of file diff --git a/preprocessing/nextclade/src/loculus_preprocessing/backend.py b/preprocessing/nextclade/src/loculus_preprocessing/backend.py index 9dbdc4883..a407a9eb5 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/backend.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/backend.py @@ -140,10 +140,9 @@ def submit_processed_sequences( if not response.ok: Path("failed_submission.json").write_text(ndjson_string, encoding="utf-8") msg = ( - f"Submitting processed data failed. Status code: { - response.status_code}\n" + f"Submitting processed data failed. Status code: {response.status_code}\n" f"Response: {response.text}\n" - f"Data sent in request: {ndjson_string[0:1000]}...\n" + f"Data sent: {ndjson_string[:1000]}...\n" ) raise RuntimeError(msg) logging.info("Processed data submitted successfully") diff --git a/preprocessing/nextclade/src/loculus_preprocessing/config.py b/preprocessing/nextclade/src/loculus_preprocessing/config.py index 3970acafb..b2fa2a568 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/config.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/config.py @@ -18,7 +18,7 @@ @dataclass class Config: organism: str = "mpox" - backend_host: str = "" # populated in get_config if left empty + backend_host: str = "" # populated in get_config if left empty, so we can use organism keycloak_host: str = "http://127.0.0.1:8083" keycloak_user: str = "preprocessing_pipeline" keycloak_password: str = "preprocessing_pipeline" @@ -29,7 +29,7 @@ class Config: config_file: str | None = None log_level: str = "DEBUG" genes: list[str] = dataclasses.field(default_factory=list) - nucleotideSequences: list[str] = dataclasses.field(default_factory=lambda: ["main"]) + nucleotideSequences: list[str] = dataclasses.field(default_factory=lambda: ["main"]) # noqa: N815 keep_tmp_dir: bool = False reference_length: int = 197209 batch_size: int = 5 @@ -37,8 +37,8 @@ class Config: pipeline_version: int = 1 -def load_config_from_yaml(config_file: str, config: Config) -> Config: - config = copy.deepcopy(config) +def load_config_from_yaml(config_file: str, config: Config | None = None) -> Config: + config = Config() if config is None else copy.deepcopy(config) with open(config_file, encoding="utf-8") as file: yaml_config = yaml.safe_load(file) logging.debug(f"Loaded config from {config_file}: {yaml_config}") @@ -78,8 +78,14 @@ def generate_argparse_from_dataclass(config_cls: type[Config]) -> argparse.Argum def get_config(config_file: str | None = None) -> Config: - # Config precedence: CLI args > 
ENV variables > config file > default + """ + Config precedence: Direct function args > CLI args > ENV variables > config file > default + args: + config_file: Path to YAML config file - only used by tests + """ + + # Set just log level this early from env, so we can debug log during config loading env_log_level = os.environ.get("PREPROCESSING_LOG_LEVEL") if env_log_level: logging.basicConfig(level=env_log_level) @@ -87,16 +93,13 @@ def get_config(config_file: str | None = None) -> Config: parser = generate_argparse_from_dataclass(Config) args = parser.parse_args() - # Load default config - config = Config() + # Use first config file present in order of precedence + config_file_path = ( + config_file or args.config_file or os.environ.get("PREPROCESSING_CONFIG_FILE") + ) - # Overwrite config with config in config_file - if config_file: - config = load_config_from_yaml(config_file, config) - if args.config_file: - config = load_config_from_yaml(args.config_file, config) - if not config.backend_host: # Check if backend_host wasn't set during initialization - config.backend_host = f"http://127.0.0.1:8079/{config.organism}" + # Start with lowest precedence config, then overwrite with higher precedence + config = load_config_from_yaml(config_file_path) if config_file_path else Config() # Use environment variables if available for key in config.__dict__: @@ -109,4 +112,7 @@ def get_config(config_file: str | None = None) -> Config: if value is not None: setattr(config, key, value) + if not config.backend_host: # Set here so we can use organism + config.backend_host = f"http://127.0.0.1:8079/{config.organism}" + return config diff --git a/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py b/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py index 8eaf34b09..aa08c273a 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py @@ -50,12 +50,12 @@ def __hash__(self): class UnprocessedData: submitter: str metadata: InputMetadata - unalignedNucleotideSequences: dict[str, NucleotideSequence] + unalignedNucleotideSequences: dict[str, NucleotideSequence] # noqa: N815 @dataclass class UnprocessedEntry: - accessionVersion: AccessionVersion # {accession}.{version} + accessionVersion: AccessionVersion # {accession}.{version} # noqa: N815 data: UnprocessedData @@ -74,25 +74,25 @@ class ProcessingSpec: # For single segment, need to generalize for multi segments later @dataclass class UnprocessedAfterNextclade: - inputMetadata: InputMetadata + inputMetadata: InputMetadata # noqa: N815 # Derived metadata produced by Nextclade - nextcladeMetadata: dict[SegmentName, Any] | None - unalignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] - alignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] - nucleotideInsertions: dict[SegmentName, list[NucleotideInsertion]] - alignedAminoAcidSequences: dict[GeneName, AminoAcidSequence | None] - aminoAcidInsertions: dict[GeneName, list[AminoAcidInsertion]] + nextcladeMetadata: dict[SegmentName, Any] | None # noqa: N815 + unalignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] # noqa: N815 + alignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] # noqa: N815 + nucleotideInsertions: dict[SegmentName, list[NucleotideInsertion]] # noqa: N815 + alignedAminoAcidSequences: dict[GeneName, AminoAcidSequence | None] # noqa: N815 + aminoAcidInsertions: dict[GeneName, list[AminoAcidInsertion]] # noqa: N815 errors: 
list[ProcessingAnnotation] @dataclass class ProcessedData: metadata: ProcessedMetadata - unalignedNucleotideSequences: dict[str, Any] - alignedNucleotideSequences: dict[str, Any] - nucleotideInsertions: dict[str, Any] - alignedAminoAcidSequences: dict[str, Any] - aminoAcidInsertions: dict[str, Any] + unalignedNucleotideSequences: dict[str, Any] # noqa: N815 + alignedNucleotideSequences: dict[str, Any] # noqa: N815 + nucleotideInsertions: dict[str, Any] # noqa: N815 + alignedAminoAcidSequences: dict[str, Any] # noqa: N815 + aminoAcidInsertions: dict[str, Any] # noqa: N815 @dataclass diff --git a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py index a51289952..07d9fa6f7 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py @@ -94,12 +94,14 @@ def parse_nextclade_tsv( def parse_nextclade_json( result_dir, - nextclade_metadata: defaultdict[AccessionVersion, defaultdict[SegmentName, dict[str, Any]]], + nextclade_metadata: defaultdict[ + AccessionVersion, defaultdict[SegmentName, dict[str, Any] | None] + ], segment: SegmentName, unaligned_nucleotide_sequences: dict[ AccessionVersion, dict[SegmentName, NucleotideSequence | None] ], -) -> defaultdict[AccessionVersion, defaultdict[SegmentName, dict[str, Any]]]: +) -> defaultdict[AccessionVersion, defaultdict[SegmentName, dict[str, Any] | None]]: """ Update nextclade_metadata object with the results of the nextclade analysis. If the segment existed in the input (unaligned_nucleotide_sequences) but did not align @@ -116,7 +118,7 @@ def parse_nextclade_json( return nextclade_metadata -def enrich_with_nextclade( +def enrich_with_nextclade( # noqa: C901, PLR0912, PLR0914, PLR0915 unprocessed: Sequence[UnprocessedEntry], dataset_dir: str, config: Config ) -> dict[AccessionVersion, UnprocessedAfterNextclade]: """ @@ -166,12 +168,14 @@ def enrich_with_nextclade( error_dict[id] = error_dict.get(id, []) error_dict[id].append( ProcessingAnnotation( - source=[ - AnnotationSource( - name=segment, - type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE, + source=( + ( + AnnotationSource( + name=segment, + type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE, + ), ) - ], + ), message="Found multiple sequences with the same segment name.", ) ) @@ -191,12 +195,12 @@ def enrich_with_nextclade( error_dict[id] = error_dict.get(id, []) error_dict[id].append( ProcessingAnnotation( - source=[ + source=( AnnotationSource( name="main", type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE, - ) - ], + ), + ), message=( "Found unknown segments in the input data - " "check your segments are annotated correctly." 
@@ -204,9 +208,9 @@ def enrich_with_nextclade( ) ) - nextclade_metadata: defaultdict[AccessionVersion, defaultdict[SegmentName, dict[str, Any]]] = ( - defaultdict(lambda: defaultdict(dict)) - ) + nextclade_metadata: defaultdict[ + AccessionVersion, defaultdict[SegmentName, dict[str, Any] | None] + ] = defaultdict(lambda: defaultdict(dict)) nucleotide_insertions: defaultdict[ AccessionVersion, defaultdict[SegmentName, list[NucleotideInsertion]] ] = defaultdict(lambda: defaultdict(list)) @@ -234,11 +238,10 @@ def enrich_with_nextclade( "run", f"--output-all={result_dir_seg}", f"--input-dataset={dataset_dir_seg}", - f"--output-translations={ - result_dir_seg}/nextclade.cds_translation.{{cds}}.fasta", + f"--output-translations={result_dir_seg}/nextclade.cds_translation.{{cds}}.fasta", "--jobs=1", "--", - f"{input_file}", + input_file, ] logging.debug(f"Running nextclade: {command}") @@ -372,12 +375,12 @@ def add_input_metadata( unprocessed: UnprocessedAfterNextclade, errors: list[ProcessingAnnotation], input_path: str, -) -> InputMetadata: +) -> str | None: """Returns value of input_path in unprocessed metadata""" # If field starts with "nextclade.", take from nextclade metadata nextclade_prefix = "nextclade." if input_path.startswith(nextclade_prefix): - segment = spec.args.get("segment", "main") + segment = spec.args["segment"] if spec.args and "segment" in spec.args else "main" if not unprocessed.nextcladeMetadata: # This field should never be empty message = ( @@ -386,12 +389,12 @@ def add_input_metadata( ) errors.append( ProcessingAnnotation( - source=[ + source=( AnnotationSource( name="segment", type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE, - ) - ], + ), + ), message=message, ) ) @@ -406,17 +409,17 @@ def add_input_metadata( ) errors.append( ProcessingAnnotation( - source=[ + source=( AnnotationSource( name=segment, type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE, - ) - ], + ), + ), message=message, ) ) return None - result = str( + result: str | None = str( dpath.get( unprocessed.nextcladeMetadata[segment], sub_path, @@ -447,7 +450,7 @@ def add_input_metadata( return unprocessed.inputMetadata[input_path] -def get_metadata( +def get_metadata( # noqa: PLR0913, PLR0917 id: AccessionVersion, spec: ProcessingSpec, output_field: str, @@ -491,7 +494,7 @@ def get_metadata( return processing_result -def processed_entry_no_alignment( +def processed_entry_no_alignment( # noqa: PLR0913, PLR0917 id: AccessionVersion, unprocessed: UnprocessedData, config: Config, @@ -538,7 +541,7 @@ def processed_entry_no_alignment( ) -def process_single( +def process_single( # noqa: C901 id: AccessionVersion, unprocessed: UnprocessedAfterNextclade | UnprocessedData, config: Config ) -> ProcessedEntry: """Process a single sequence per config""" @@ -751,7 +754,8 @@ def run(config: Config) -> None: logging.debug("No unprocessed sequences found. 
Sleeping for 1 second.") time.sleep(1) continue -        # Don't use etag if we just got data, preprocessing only asks for 100 sequences to process at a time, so there might be more +        # Don't use etag if we just got data: +        # preprocessing only asks for 100 sequences to process at a time, so there might be more etag = None try: processed = process_all(unprocessed, dataset_dir, config) diff --git a/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py b/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py index 216d2b614..663b2ff80 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py @@ -162,7 +162,7 @@ ) @staticmethod -    def parse_and_assert_past_date( +    def parse_and_assert_past_date(  # noqa: C901 input_data: InputMetadata, output_field, args: FunctionArgs = None,  # args is essential - even if Pylance says it's not used @@ -245,7 +245,10 @@ parse_and_assert_past_date name=output_field, type=AnnotationSourceType.METADATA ) ], -                            message=f"Metadata field {output_field}:'{date_str}' is after release date.", +                            message=( +                                f"Metadata field {output_field}:'{date_str}' " +                                "is after release date." +                            ), ) ) @@ -399,7 +402,7 @@ concatenate ) @staticmethod -    def identity( +    def identity(  # noqa: C901, PLR0912 input_data: InputMetadata, output_field: str, args: FunctionArgs = None ) -> ProcessingResult: """Identity function, takes input_data["input"] and returns it as output""" diff --git a/preprocessing/nextclade/src/loculus_preprocessing/sequence_checks.py b/preprocessing/nextclade/src/loculus_preprocessing/sequence_checks.py index 7a2553a1f..9713ec194 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/sequence_checks.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/sequence_checks.py @@ -40,7 +40,10 @@ errors_if_non_iupac name=segment, type=AnnotationSourceType.NUCLEOTIDE_SEQUENCE ) ], -            message=f"Found non-IUPAC symbols in the {segment} sequence: {', '.join(non_iupac_symbols)}", +            message=( +                f"Found non-IUPAC symbols in the {segment} sequence: " +                + ", ".join(non_iupac_symbols) +            ), ) ) return errors diff --git a/preprocessing/nextclade/tests/factory_methods.py b/preprocessing/nextclade/tests/factory_methods.py index 9538e28ed..13b03b8b9 100644 --- a/preprocessing/nextclade/tests/factory_methods.py +++ b/preprocessing/nextclade/tests/factory_methods.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from loculus_preprocessing.datatypes import ( AnnotationSource, @@ -12,23 +12,22 @@ @dataclass -class TestCase: +class ProcessingTestCase: name: str input: UnprocessedEntry expected_output: ProcessedEntry + @dataclass class UnprocessedEntryFactory: -    _counter: int = 0 @staticmethod def create_unprocessed_entry( metadata_dict: dict[str, str], +        accession_id: str, ) -> UnprocessedEntry: -        unique_id = str(UnprocessedEntryFactory._counter) -        UnprocessedEntryFactory._counter += 1 return UnprocessedEntry( -            accessionVersion="LOC_" + unique_id + ".1", +            accessionVersion=f"LOC_{accession_id}.1", data=UnprocessedData( submitter="test_submitter", metadata=metadata_dict, @@ -39,16 +38,16 @@ @dataclass class ProcessedEntryFactory: -    _counter: int = 0 -    _all_metadata_fields: list[str] | None = field(default=None) +    all_metadata_fields: list[str] | None = None -    def __init__(self, all_metadata_fields: list[str] | None = None): -        if 
all_metadata_fields is not None: - self._all_metadata_fields = all_metadata_fields + def __post_init__(self): + if self.all_metadata_fields is None: + self.all_metadata_fields = [] def create_processed_entry( self, metadata_dict: dict[str, str], + accession: str, metadata_errors: list[tuple[str, str]] | None = None, metadata_warnings: list[tuple[str, str]] | None = None, ) -> ProcessedEntry: @@ -56,15 +55,12 @@ def create_processed_entry( metadata_errors = [] if metadata_warnings is None: metadata_warnings = [] - if self._all_metadata_fields: - base_metadata_dict = dict.fromkeys(self._all_metadata_fields) - base_metadata_dict.update(metadata_dict) - else: - base_metadata_dict = metadata_dict - unique_id = str(ProcessedEntryFactory._counter) - ProcessedEntryFactory._counter += 1 + + base_metadata_dict = dict.fromkeys(self.all_metadata_fields) + base_metadata_dict.update(metadata_dict) + return ProcessedEntry( - accession="LOC_" + unique_id, + accession=accession, version=1, data=ProcessedData( metadata=base_metadata_dict, @@ -88,4 +84,4 @@ def create_processed_entry( ) for warning in metadata_warnings ], - ) \ No newline at end of file + ) diff --git a/preprocessing/nextclade/tests/test.py b/preprocessing/nextclade/tests/test.py deleted file mode 100644 index 1b77028aa..000000000 --- a/preprocessing/nextclade/tests/test.py +++ /dev/null @@ -1,387 +0,0 @@ -import unittest - -from factory_methods import ProcessedEntryFactory, TestCase, UnprocessedEntryFactory - -from loculus_preprocessing.config import Config, get_config -from loculus_preprocessing.datatypes import ( - ProcessedEntry, - ProcessingAnnotation, -) -from loculus_preprocessing.prepro import process_all -from loculus_preprocessing.processing_functions import format_frameshift, format_stop_codon - -test_config_file = "tests/test_config.yaml" - - -def get_test_cases(config: Config) -> list[TestCase]: - factory_custom = ProcessedEntryFactory(all_metadata_fields=list(config.processing_spec.keys())) - return [ - TestCase( - name="missing_required_fields", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "missing_required_fields", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "concatenated_string": "LOC_0.1", - }, - metadata_errors=[ - ("name_required", "Metadata field name_required is required."), - ( - "required_collection_date", - "Metadata field required_collection_date is required.", - ), - ], - ), - ), - TestCase( - name="missing_one_required_field", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "missing_one_required_field", - "name_required": "name", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "concatenated_string": "LOC_1.1", - }, - metadata_errors=[ - ( - "required_collection_date", - "Metadata field required_collection_date is required.", - ), - ], - ), - ), - TestCase( - name="invalid_option", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_option", - "continent": "Afrika", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "Afrika/LOC_2.1/2022-11-01", - }, - metadata_errors=[ - ( - "continent", - "Metadata field continent:'Afrika' - not in list of accepted options.", - ), - ], 
- ), - ), - TestCase( - name="collection_date_in_future", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "collection_date_in_future", - "collection_date": "2088-12-01", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "collection_date": "2088-12-01", - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_3.1/2022-11-01", - }, - metadata_errors=[ - ( - "collection_date", - "Metadata field collection_date:'2088-12-01' is in the future.", - ), - ], - ), - ), - TestCase( - name="invalid_collection_date", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_collection_date", - "collection_date": "01-02-2024", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_4.1/2022-11-01", - }, - metadata_errors=[ - ( - "collection_date", - "Metadata field collection_date: Date format is not recognized.", - ), - ], - ), - ), - TestCase( - name="invalid_timestamp", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_timestamp", - "sequenced_timestamp": " 2022-11-01Europe", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_5.1/2022-11-01", - }, - metadata_errors=[ - ( - "sequenced_timestamp", - "Timestamp is 2022-11-01Europe which is not in parseable YYYY-MM-DD. Parsing error: Unknown string format: 2022-11-01Europe", - ), - ], - ), - ), - TestCase( - name="date_only_year", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "date_only_year", - "collection_date": "2023", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "collection_date": "2023-01-01", - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_6.1/2022-11-01", - }, - metadata_errors=[], - metadata_warnings=[ - ( - "collection_date", - "Metadata field collection_date:'2023' - Month and day are missing. Assuming January 1st.", - ), - ], - ), - ), - TestCase( - name="date_no_day", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "date_no_day", - "collection_date": "2023-12", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "collection_date": "2023-12-01", - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_7.1/2022-11-01", - }, - metadata_errors=[], - metadata_warnings=[ - ( - "collection_date", - "Metadata field collection_date:'2023-12' - Day is missing. 
Assuming the 1st.", - ), - ], - ), - ), - TestCase( - name="invalid_int", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_int", - "age_int": "asdf", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_8.1/2022-11-01", - }, - metadata_errors=[ - ("age_int", "Invalid int value: asdf for field age_int."), - ], - ), - ), - TestCase( - name="invalid_float", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_float", - "percentage_float": "asdf", - "name_required": "name", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_9.1/2022-11-01", - }, - metadata_errors=[ - ("percentage_float", "Invalid float value: asdf for field percentage_float."), - ], - ), - ), - TestCase( - name="invalid_date", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_date", - "name_required": "name", - "other_date": "01-02-2024", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_10.1/2022-11-01", - }, - metadata_errors=[ - ( - "other_date", - "Date is 01-02-2024 which is not in the required format YYYY-MM-DD. Parsing error: time data '01-02-2024' does not match format '%Y-%m-%d'", - ), - ], - ), - ), - TestCase( - name="invalid_boolean", - input=UnprocessedEntryFactory.create_unprocessed_entry( - metadata_dict={ - "submissionId": "invalid_boolean", - "name_required": "name", - "is_lab_host_bool": "maybe", - "required_collection_date": "2022-11-01", - } - ), - expected_output=factory_custom.create_processed_entry( - metadata_dict={ - "name_required": "name", - "required_collection_date": "2022-11-01", - "concatenated_string": "LOC_11.1/2022-11-01", - }, - metadata_errors=[ - ( - "is_lab_host_bool", - "Invalid boolean value: maybe for field is_lab_host_bool.", - ), - ], - ), - ), - ] - - -def sort_annotations(annotations: list[ProcessingAnnotation]): - return sorted(annotations, key=lambda x: (x.source[0].name, x.message)) - - -class PreprocessingTests(unittest.TestCase): - def test_process_all(self) -> None: - config: Config = get_config(test_config_file) - test_cases = get_test_cases(config=config) - for test_case in test_cases: - dataset_dir = "temp" # This is not used as we do not align sequences - result: list[ProcessedEntry] = process_all([test_case.input], dataset_dir, config) - processed_entry = result[0] - if ( - processed_entry.accession != test_case.expected_output.accession - or processed_entry.version != test_case.expected_output.version - ): - message = ( - f"{test_case.name}: processed entry accessionVersion {processed_entry.accession}" - f".{processed_entry.version} does not match expected output " - f"{test_case.expected_output.accession}.{test_case.expected_output.version}." 
- ) - raise AssertionError(message) - if processed_entry.data != test_case.expected_output.data: - message = ( - f"{test_case.name}: processed metadata {processed_entry.data} does not" - f" match expected output {test_case.expected_output.data}." - ) - raise AssertionError(message) - if sort_annotations(processed_entry.errors) != sort_annotations( - test_case.expected_output.errors - ): - message = ( - f"{test_case.name}: processed errors: {processed_entry.errors} does not " - f"match expected output: {test_case.expected_output.errors}." - ) - raise AssertionError(message) - if sort_annotations(processed_entry.warnings) != sort_annotations( - test_case.expected_output.warnings - ): - message = ( - f"{test_case.name}: processed warnings {processed_entry.warnings} does not" - f" match expected output {test_case.expected_output.warnings}." - ) - raise AssertionError(message) - - def test_format_frameshift(self): - # Test case 1: Empty input - self.assertEqual(format_frameshift("[]"), "") - - # Test case 2: Single frameshift - input_single = '[{"cdsName": "GPC", "nucRel": {"begin": 5, "end": 20}, "nucAbs": [{"begin": 97, "end": 112}], "codon": {"begin": 2, "end": 7}, "gapsLeading": {"begin": 1, "end": 2}, "gapsTrailing": {"begin": 7, "end": 8}}]' - expected_single = "GPC:3-7(nt:98-112)" - self.assertEqual(format_frameshift(input_single), expected_single) - - # Test case 3: Multiple frameshifts - input_multiple = '[{"cdsName": "GPC", "nucRel": {"begin": 5, "end": 20}, "nucAbs": [{"begin": 97, "end": 112}], "codon": {"begin": 2, "end": 7}, "gapsLeading": {"begin": 1, "end": 2}, "gapsTrailing": {"begin": 7, "end": 8}}, {"cdsName": "NP", "nucRel": {"begin": 10, "end": 15}, "nucAbs": [{"begin": 200, "end": 205}], "codon": {"begin": 3, "end": 5}, "gapsLeading": {"begin": 2, "end": 3}, "gapsTrailing": {"begin": 5, "end": 6}}]' - expected_multiple = "GPC:3-7(nt:98-112),NP:4-5(nt:201-205)" - self.assertEqual(format_frameshift(input_multiple), expected_multiple) - - # Test case 4: Single nucleotide frameshift - input_single_nuc = '[{"cdsName": "L", "nucRel": {"begin": 30, "end": 31}, "nucAbs": [{"begin": 500, "end": 501}], "codon": {"begin": 10, "end": 11}, "gapsLeading": {"begin": 9, "end": 10}, "gapsTrailing": {"begin": 11, "end": 12}}]' - expected_single_nuc = "L:11(nt:501)" - self.assertEqual(format_frameshift(input_single_nuc), expected_single_nuc) - - def test_format_stop_codon(self): - # Test case 1: Empty input - self.assertEqual(format_stop_codon("[]"), "") - - # Test case 2: Single stop codon - input_single = '[{"cdsName": "GPC", "codon": 123}]' - expected_single = "GPC:124" - self.assertEqual(format_stop_codon(input_single), expected_single) - - # Test case 3: Multiple stop codons - input_multiple = '[{"cdsName": "GPC", "codon": 123}, {"cdsName": "NP", "codon": 456}]' - expected_multiple = "GPC:124,NP:457" - self.assertEqual(format_stop_codon(input_multiple), expected_multiple) - - # Test case 4: Stop codon at position 0 - input_zero = '[{"cdsName": "L", "codon": 0}]' - expected_zero = "L:1" - self.assertEqual(format_stop_codon(input_zero), expected_zero) - - -if __name__ == "__main__": - unittest.main() diff --git a/preprocessing/nextclade/tests/test_processing_functions.py b/preprocessing/nextclade/tests/test_processing_functions.py new file mode 100644 index 000000000..985ecc63d --- /dev/null +++ b/preprocessing/nextclade/tests/test_processing_functions.py @@ -0,0 +1,389 @@ +# ruff: noqa: S101 +from dataclasses import dataclass + +import pytest +from factory_methods import 
ProcessedEntryFactory, ProcessingTestCase, UnprocessedEntryFactory + +from loculus_preprocessing.config import Config, get_config +from loculus_preprocessing.datatypes import ProcessedEntry, ProcessingAnnotation +from loculus_preprocessing.prepro import process_all +from loculus_preprocessing.processing_functions import format_frameshift, format_stop_codon + +# Config file used for testing +test_config_file = "tests/test_config.yaml" + + +@dataclass +class Case: +    name: str +    metadata: dict[str, str] +    expected_metadata: dict[str, str] +    expected_errors: list[tuple[str, str]] +    expected_warnings: list[tuple[str, str]] | None = None +    accession_id: str = "000999" + +    def create_test_case(self, factory_custom: ProcessedEntryFactory) -> ProcessingTestCase: +        unprocessed_entry = UnprocessedEntryFactory.create_unprocessed_entry( +            metadata_dict=self.metadata, +            accession_id=self.accession_id, +        ) +        expected_output = factory_custom.create_processed_entry( +            metadata_dict=self.expected_metadata, +            accession=unprocessed_entry.accessionVersion.split(".")[0], +            metadata_errors=self.expected_errors, +            metadata_warnings=self.expected_warnings or [], +        ) +        return ProcessingTestCase( +            name=self.name, input=unprocessed_entry, expected_output=expected_output +        ) + + +test_case_definitions = [ +    Case( +        name="missing_required_fields", +        metadata={"submissionId": "missing_required_fields"}, +        accession_id="0", +        expected_metadata={"concatenated_string": "LOC_0.1"}, +        expected_errors=[ +            ("name_required", "Metadata field name_required is required."), +            ( +                "required_collection_date", +                "Metadata field required_collection_date is required.", +            ), +        ], +    ), +    Case( +        name="missing_one_required_field", +        metadata={"submissionId": "missing_one_required_field", "name_required": "name"}, +        accession_id="1", +        expected_metadata={"name_required": "name", "concatenated_string": "LOC_1.1"}, +        expected_errors=[ +            ( +                "required_collection_date", +                "Metadata field required_collection_date is required.", +            ), +        ], +    ), +    Case( +        name="invalid_option", +        metadata={ +            "submissionId": "invalid_option", +            "continent": "Afrika", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="2", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "Afrika/LOC_2.1/2022-11-01", +        }, +        expected_errors=[ +            ( +                "continent", +                "Metadata field continent:'Afrika' - not in list of accepted options.", +            ), +        ], +    ), +    Case( +        name="collection_date_in_future", +        metadata={ +            "submissionId": "collection_date_in_future", +            "collection_date": "2088-12-01", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="3", +        expected_metadata={ +            "collection_date": "2088-12-01", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_3.1/2022-11-01", +        }, +        expected_errors=[ +            ( +                "collection_date", +                "Metadata field collection_date:'2088-12-01' is in the future.", +            ), +        ], +    ), +    Case( +        name="invalid_collection_date", +        metadata={ +            "submissionId": "invalid_collection_date", +            "collection_date": "01-02-2024", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="4", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_4.1/2022-11-01", +        }, +        expected_errors=[ +            ( +                "collection_date", +                "Metadata field collection_date: Date format is not recognized.", +            ), +        ], +    ), +    Case( +        
name="invalid_timestamp", +        metadata={ +            "submissionId": "invalid_timestamp", +            "sequenced_timestamp": " 2022-11-01Europe", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="5", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_5.1/2022-11-01", +        }, +        expected_errors=[ +            ( +                "sequenced_timestamp", +                ( +                    "Timestamp is 2022-11-01Europe which is not in parseable YYYY-MM-DD. " +                    "Parsing error: Unknown string format: 2022-11-01Europe" +                ), +            ), +        ], +    ), +    Case( +        name="date_only_year", +        metadata={ +            "submissionId": "date_only_year", +            "collection_date": "2023", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="6", +        expected_metadata={ +            "collection_date": "2023-01-01", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_6.1/2022-11-01", +        }, +        expected_errors=[], +        expected_warnings=[ +            ( +                "collection_date", +                ( +                    "Metadata field collection_date:'2023' - Month and day are missing. " +                    "Assuming January 1st." +                ), +            ), +        ], +    ), +    Case( +        name="date_no_day", +        metadata={ +            "submissionId": "date_no_day", +            "collection_date": "2023-12", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="7", +        expected_metadata={ +            "collection_date": "2023-12-01", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_7.1/2022-11-01", +        }, +        expected_errors=[], +        expected_warnings=[ +            ( +                "collection_date", +                "Metadata field collection_date:'2023-12' - Day is missing. Assuming the 1st.", +            ), +        ], +    ), +    Case( +        name="invalid_int", +        metadata={ +            "submissionId": "invalid_int", +            "age_int": "asdf", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="8", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_8.1/2022-11-01", +        }, +        expected_errors=[ +            ("age_int", "Invalid int value: asdf for field age_int."), +        ], +    ), +    Case( +        name="invalid_float", +        metadata={ +            "submissionId": "invalid_float", +            "percentage_float": "asdf", +            "name_required": "name", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="9", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_9.1/2022-11-01", +        }, +        expected_errors=[ +            ("percentage_float", "Invalid float value: asdf for field percentage_float."), +        ], +    ), +    Case( +        name="invalid_date", +        metadata={ +            "submissionId": "invalid_date", +            "name_required": "name", +            "other_date": "01-02-2024", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="10", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_10.1/2022-11-01", +        }, +        expected_errors=[ +            ( +                "other_date", +                ( +                    "Date is 01-02-2024 which is not in the required format YYYY-MM-DD. " +                    "Parsing error: time data '01-02-2024' does not match format '%Y-%m-%d'" +                ), +            ), +        ], +    ), +    Case( +        name="invalid_boolean", +        metadata={ +            "submissionId": "invalid_boolean", +            "name_required": "name", +            "is_lab_host_bool": "maybe", +            "required_collection_date": "2022-11-01", +        }, +        accession_id="11", +        expected_metadata={ +            "name_required": "name", +            "required_collection_date": "2022-11-01", +            "concatenated_string": "LOC_11.1/2022-11-01", +        }, +        expected_errors=[ +            ("is_lab_host_bool", "Invalid boolean value: maybe for field is_lab_host_bool."), +        ], +    ), +] + + +@pytest.fixture(scope="module") +def config(): +    return get_config(test_config_file) + + +@pytest.fixture(scope="module") +def factory_custom(config): +    return ProcessedEntryFactory(all_metadata_fields=list(config.processing_spec.keys())) + + +def sort_annotations(annotations: list[ProcessingAnnotation]) -> list[ProcessingAnnotation]: +    return sorted(annotations, key=lambda x: (x.source[0].name, x.message)) + + +def process_single_entry(test_case: ProcessingTestCase, config: Config) -> ProcessedEntry: +    dataset_dir = "temp"  # This is not used as we do not align sequences +    result = process_all([test_case.input], dataset_dir, config) +    return result[0] + + +def verify_processed_entry( +    processed_entry: ProcessedEntry, expected_output: ProcessedEntry, test_name: str +) -> None: +    # Check accession and version +    assert ( +        processed_entry.accession == expected_output.accession +        and processed_entry.version == expected_output.version +    ), ( +        f"{test_name}: processed entry accessionVersion " +        f"{processed_entry.accession}.{processed_entry.version} " +        f"does not match expected output {expected_output.accession}.{expected_output.version}." +    ) + +    # Check metadata +    assert processed_entry.data.metadata == expected_output.data.metadata, ( +        f"{test_name}: processed metadata {processed_entry.data.metadata} " +        f"does not match expected metadata {expected_output.data.metadata}." +    ) + +    # Check errors +    processed_errors = sort_annotations(processed_entry.errors) +    expected_errors = sort_annotations(expected_output.errors) +    assert processed_errors == expected_errors, ( +        f"{test_name}: processed errors: {processed_errors} " +        f"does not match expected output: {expected_errors}." +    ) + +    # Check warnings +    processed_warnings = sort_annotations(processed_entry.warnings) +    expected_warnings = sort_annotations(expected_output.warnings) +    assert processed_warnings == expected_warnings, ( +        f"{test_name}: processed warnings {processed_warnings} " +        f"does not match expected output {expected_warnings}."
+ ) + + +@pytest.mark.parametrize("test_case_def", test_case_definitions, ids=lambda tc: tc.name) +def test_preprocessing(test_case_def: Case, config: Config, factory_custom: ProcessedEntryFactory): + test_case = test_case_def.create_test_case(factory_custom) + processed_entry = process_single_entry(test_case, config) + verify_processed_entry(processed_entry, test_case.expected_output, test_case.name) + + +def test_format_frameshift(): + # Test case 1: Empty input + assert not format_frameshift("[]") + + # Test case 2: Single frameshift + input_single = '[{"cdsName": "GPC", "nucRel": {"begin": 5, "end": 20}, "nucAbs": [{"begin": 97, "end": 112}], "codon": {"begin": 2, "end": 7}, "gapsLeading": {"begin": 1, "end": 2}, "gapsTrailing": {"begin": 7, "end": 8}}]' # noqa: E501 + expected_single = "GPC:3-7(nt:98-112)" + assert format_frameshift(input_single) == expected_single + + # Test case 3: Multiple frameshifts + input_multiple = '[{"cdsName": "GPC", "nucRel": {"begin": 5, "end": 20}, "nucAbs": [{"begin": 97, "end": 112}], "codon": {"begin": 2, "end": 7}, "gapsLeading": {"begin": 1, "end": 2}, "gapsTrailing": {"begin": 7, "end": 8}}, {"cdsName": "NP", "nucRel": {"begin": 10, "end": 15}, "nucAbs": [{"begin": 200, "end": 205}], "codon": {"begin": 3, "end": 5}, "gapsLeading": {"begin": 2, "end": 3}, "gapsTrailing": {"begin": 5, "end": 6}}]' # noqa: E501 + expected_multiple = "GPC:3-7(nt:98-112),NP:4-5(nt:201-205)" + assert format_frameshift(input_multiple) == expected_multiple + + # Test case 4: Single nucleotide frameshift + input_single_nuc = '[{"cdsName": "L", "nucRel": {"begin": 30, "end": 31}, "nucAbs": [{"begin": 500, "end": 501}], "codon": {"begin": 10, "end": 11}, "gapsLeading": {"begin": 9, "end": 10}, "gapsTrailing": {"begin": 11, "end": 12}}]' # noqa: E501 + expected_single_nuc = "L:11(nt:501)" + assert format_frameshift(input_single_nuc) == expected_single_nuc + + +def test_format_stop_codon(): + # Test case 1: Empty input + assert not format_stop_codon("[]") + + # Test case 2: Single stop codon + input_single = '[{"cdsName": "GPC", "codon": 123}]' + expected_single = "GPC:124" + assert format_stop_codon(input_single) == expected_single + + # Test case 3: Multiple stop codons + input_multiple = '[{"cdsName": "GPC", "codon": 123}, {"cdsName": "NP", "codon": 456}]' + expected_multiple = "GPC:124,NP:457" + assert format_stop_codon(input_multiple) == expected_multiple + + # Test case 4: Stop codon at position 0 + input_zero = '[{"cdsName": "L", "codon": 0}]' + expected_zero = "L:1" + assert format_stop_codon(input_zero) == expected_zero + + +if __name__ == "__main__": + pytest.main()
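Since the test cases are parametrized with `ids=lambda tc: tc.name`, individual scenarios can be selected with pytest's standard `-k` filter; for example (assuming the commands are run from `preprocessing/nextclade/`):

```sh
# Run the whole suite
pytest

# Run a single parametrized case by name, with verbose output
pytest tests/test_processing_functions.py -k invalid_option -v
```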