diff --git a/Makefile b/Makefile
index 9d8a0c0..5e75969 100644
--- a/Makefile
+++ b/Makefile
@@ -39,5 +39,5 @@ build-airflow-image: generate-requirements ## build local airflow image for the
 	-f Dockerfile \
 	--no-cache
-upload-gwas-catalog-bucket-readme: ## Upload gwas_catalog_bucket readme to the bucket.
-	@gsutil cp docs/datasources/gwas_catalog_data/README.md gs://gwas_catalog_data/README.md
+upload-ukb-ppp-bucket-readme: ## Upload ukb_ppp_eur_data readme to the bucket
+	@gsutil rsync docs/datasources/ukb_ppp_eur_data gs://ukb_ppp_eur_data/docs
diff --git a/docs/README.md b/docs/README.md
index fb9e595..771c433 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1 +1,13 @@
+# Orchestration documentation
+
 This catalog describes how the orchestration works in the current state
+
+### How to generate DAG SVG files
+
+1. Locate your global `airflow.cfg` file and set the `dags_folder` option in the `[core]` section to point to the `src` directory of the orchestration repository, or set the `AIRFLOW__CORE__DAGS_FOLDER` environment variable.
+
+2. Run
+
+   ```bash
+   poetry run airflow dags show --save docs/${DAG_NAME}.svg ${DAG_NAME}
+   ```
diff --git a/docs/datasources/ukb_ppp_eur_data/README.md b/docs/datasources/ukb_ppp_eur_data/README.md
new file mode 100644
index 0000000..ea85ac1
--- /dev/null
+++ b/docs/datasources/ukb_ppp_eur_data/README.md
@@ -0,0 +1,119 @@
+# UK Biobank Pharma Proteomics Project (UKB-PPP)
+
+This document was updated on 2024-10-11.
+
+The data comes from `https://registry.opendata.aws/ukbppp/`.
+
+Data stored in the `gs://ukb_ppp_eur_data` bucket has the following structure:
+
+```
+gs://ukb_ppp_eur_data/credible_set_datasets/susie
+gs://ukb_ppp_eur_data/docs/
+gs://ukb_ppp_eur_data/finemapping_logs/
+gs://ukb_ppp_eur_data/finemapping_manifests/
+gs://ukb_ppp_eur_data/harmonised_summary_statistics/
+gs://ukb_ppp_eur_data/study_index/
+gs://ukb_ppp_eur_data/study_locus_lb_clumped/
+gs://ukb_ppp_eur_data/test/
+```
+
+## Processing description
+
+## Pre-steps
+
+A full description of the process can be found in https://github.com/opentargets/issues/issues/3234.
+
+### 1. Mirror
+
+- **Input.** The original data is hosted on Synapse.
+- **Transformation.**
+  - As decided previously, we keep a copy of the original data in the Vault.
+  - The protocol is available here: https://github.com/opentargets/gentropy-vault/blob/main/datasets/ukb-ppp.md.
+  - The protocol must be run manually.
+- **Output.** The output of this step is kept forever in the Vault.
+
+### 2. Preprocess
+
+- **Input.** The mirrored data from the previous step.
+- **Transformation.**
+  - The mirrored data has to undergo several specific transformations which aren't achievable in Spark (especially the first one):
+    - Extract the gzipped per-chromosome files from inside the individual TAR files, decompress them, and partition by chromosome.
+    - Recreate the study ID. This is required because multiple rows in the study index can reuse the same summary stats file.
+    - Drop rows which don't have a corresponding summary stats file.
+  - This transformation is done using Google Batch. The code can be found in this new repository: https://github.com/opentargets/gentropy-input-support. The UKB PPP-specific part is this one: https://github.com/opentargets/gentropy-input-support/blob/dc5f8f7aee13a066933f3fd5b18a9b3a5ca71069/data_sources.py#L43-L103.
+  - The command to run is `./submit.py ukb_ppp_eur` inside the `gentropy-input-support` repository.
+  - This step must also be triggered manually; how to do this is described in the repository.
+- **Output.** Precursors of the study index and summary statistics datasets. Because we decided not to keep the data twice, the output of this step is kept only temporarily and is deleted after 60 days according to the _gs://gentropy-tmp_ bucket lifecycle rules (see the paths below).
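+
+For orientation, the harmonisation DAG described below reads this temporary preprocess output from the following locations (paths as configured in `ukb_ppp_eur_harmonisation.yaml`; shown here as an illustrative listing only):
+
+```
+gs://gentropy-tmp/batch/output/ukb_ppp_eur/study_index.tsv
+gs://gentropy-tmp/batch/output/ukb_ppp_eur/summary_stats.parquet
+```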
+
+## Orchestration steps
+
+### ukb_ppp_eur_harmonisation dag
+
+The **harmonisation DAG** contains two steps:
+
+- raw sumstat preprocessing (ukb_ppp_eur_sumstat_preprocess)
+- locus breaker clumping (locus_breaker_clumping)
+
+The configuration of the Dataproc infrastructure and the individual step parameters can be found in the `ukb_ppp_eur_harmonisation.yaml` file.
+
+#### ukb_ppp_eur_sumstat_preprocess
+
+![harmonisation dag](ukb_ppp_eur_harmonisation.svg)
+
+This process **harmonizes the raw pre-processed data** to the [SummaryStatistics](https://opentargets.github.io/gentropy/python_api/datasets/summary_statistics/) dataset and creates the [StudyIndex](https://opentargets.github.io/gentropy/python_api/datasets/study_index/).
+The process runs on a **Dataproc** cluster.
+
+The outputs are stored in:
+
+- `gs://ukb_ppp_eur_data/study_index` - study index
+- `gs://ukb_ppp_eur_data/harmonised_summary_statistics` - summary statistics
+
+#### locus_breaker_clumping
+
+This process performs locus clumping on the previously harmonised summary statistics and results in a [StudyLocus](https://opentargets.github.io/gentropy/python_api/datasets/study_locus/) dataset stored under `gs://ukb_ppp_eur_data/study_locus_lb_clumped`.
+
+#### Parametrization of dataproc preprocessing jobs
+
+To parametrize the Dataproc cluster, update the `dataproc` block in the `ukb_ppp_eur_harmonisation.yaml` file.
+
+### ukb_ppp_eur_finemapping dag
+
+![finemapping dag](ukb_ppp_eur_finemapping.svg)
+
+This DAG performs fine mapping with the SuSiE-inf algorithm on the clumped study loci to obtain [credible sets](https://opentargets.github.io/gentropy/python_api/datasets/study_locus/). This is an expensive process and is run on Google Batch.
+
+Due to infrastructure constraints, the fine mapping process is divided into three steps:
+
+- [x] Generate manifests - `FinemappingBatchJobManifestOperator`
+- [x] Execute the finemapping batch job (one finemapping step per manifest) - `FinemappingBatchOperator`
+- [x] Collect finemapping logs
+
+![finemapping](finemapping.svg)
+
+1. Tasks performed by `FinemappingBatchJobManifestOperator`
+
+- Collect all individual loci parquet files.
+- Partition the collected loci into batches, with `max_records_per_chunk` as the batch size limit.
+- For each batch, create a manifest file that is used as input to the fine mapping gentropy step.
+- Save the batch manifests to Google Cloud Storage.
+
+2. Tasks performed by `FinemappingBatchOperator`
+
+- Execute one Google Batch job per manifest with `n <= max_records_per_chunk` tasks.
+- Each task executes the finemapping step on a single `StudyLocus` record.
+
+3.
Collect logs + +The output of finemapping can be found under the: + +- `gs://ukb_ppp_eur_data/credible_set_datasets/susie/` - fine mapped study loci +- `gs://ukb_ppp_eur_data/finemapping_manifests/` - manifests used during the fine mapping job +- `gs://ukb_ppp_eur_data/finemapping_logs/` - logs from the individual finemapping tasks + +#### Parametrization of google batch finemapping job + +The configuration of the google batch infrastructure and individual step parameters can be found in `ukb_ppp_eur_finemapping.yaml` file. +To adjust the parameters for google batch infrastructure refer to the `google_batch` block in the node configuration. + +> [!WARNING] +> After running the google batch fine mapping job, ensure that the job tasks have succeeded, otherwise the job requires manual curation. diff --git a/docs/datasources/ukb_ppp_eur_data/finemapping.svg b/docs/datasources/ukb_ppp_eur_data/finemapping.svg new file mode 100644 index 0000000..8174201 --- /dev/null +++ b/docs/datasources/ukb_ppp_eur_data/finemapping.svg @@ -0,0 +1,13 @@ + + + + + + + + StudyLocusId=1/file.parquetStudyLocusId=2/file.parquetStudyLocusId=3/file.parquetStudyLocusId=4/file.parquetStudyLocusId=5/file.parquetCollected Loci from Clumped StudyLocus datasetmax_records_per_chunk=3StudyLocusId=2/file.parquetStudyLocusId=1/file.parquetStudyLocusId=3/file.parquetStudyLocusId=4/file.parquetStudyLocusId=5/file.parquetchunk_1chunk_0ManifestsGoogle Cloud Storagemanifests/chunk_0.csvmanifests/chunk_1.csvBatch job 1Finemapping(chunk_1[0])Finemapping(chunk_1[1])Batch job 0Finemapping(chunk_0[0])Finemapping(chunk_0[1])Finemapping(chunk_0[2]) diff --git a/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_finemapping.svg b/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_finemapping.svg new file mode 100644 index 0000000..91d91d0 --- /dev/null +++ b/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_finemapping.svg @@ -0,0 +1,44 @@ + + + + + + +ukb_ppp_eur_finemapping + +ukb_ppp_eur_finemapping + + +finemapping_batch_job + +finemapping_batch_job + + + +move_finemapping_logs + +move_finemapping_logs + + + +finemapping_batch_job->move_finemapping_logs + + + + + +generate_manifests + +generate_manifests + + + +generate_manifests->finemapping_batch_job + + + + + diff --git a/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_harmonisation.svg b/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_harmonisation.svg new file mode 100644 index 0000000..3b930d6 --- /dev/null +++ b/docs/datasources/ukb_ppp_eur_data/ukb_ppp_eur_harmonisation.svg @@ -0,0 +1,56 @@ + + + + + + +ukb_ppp_eur_harmonisation + +ukb_ppp_eur_harmonisation + + +create_cluster + +create_cluster + + + +ukb_ppp_eur_sumstat_preprocess + +ukb_ppp_eur_sumstat_preprocess + + + +create_cluster->ukb_ppp_eur_sumstat_preprocess + + + + + +delete_cluster + +delete_cluster + + + +locus_breaker_clumping + +locus_breaker_clumping + + + +locus_breaker_clumping->delete_cluster + + + + + +ukb_ppp_eur_sumstat_preprocess->locus_breaker_clumping + + + + + diff --git a/poetry.lock b/poetry.lock index 7ca5425..5f70b0b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3624,6 +3624,22 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "graphviz" +version = "0.20.3" +description = "Simple Python interface for Graphviz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5"}, + {file = 
"graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d"}, +] + +[package.extras] +dev = ["flake8", "pep8-naming", "tox (>=3)", "twine", "wheel"] +docs = ["sphinx (>=5,<7)", "sphinx-autodoc-typehints", "sphinx-rtd-theme"] +test = ["coverage", "pytest (>=7,<8.1)", "pytest-cov", "pytest-mock (>=3)"] + [[package]] name = "greenlet" version = "3.1.1" @@ -7768,4 +7784,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.11" -content-hash = "e6072a728bc377fe778d98dd6f05f937f2cd006026dcf76490ec74b3471bae58" +content-hash = "870b50a07e5f4ef2e53c5e4e214308ff23994054dabb586a6d375df59e94c831" diff --git a/pyproject.toml b/pyproject.toml index 2e6a084..5da0058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ pre-commit = "^3.7.1" coverage = "^7.5.4" psycopg2-binary = "^2.9.9" interrogate = "^1.7.0" +graphviz = "^0.20.3" [tool.poetry.group.test.dependencies] diff --git a/src/ot_orchestration/dags/config/genetics_etl.yaml b/src/ot_orchestration/dags/config/genetics_etl.yaml index 39def92..2184136 100644 --- a/src/ot_orchestration/dags/config/genetics_etl.yaml +++ b/src/ot_orchestration/dags/config/genetics_etl.yaml @@ -1,6 +1,6 @@ gwas_catalog_manifests_path: gs://gwas_catalog_data/manifests l2g_gold_standard_path: gs://genetics_etl_python_playground/input/l2g/gold_standard/curation.json -release_dir: gs://ot_orchestration/releases/27.09 +release_dir: gs://ot_orchestration/releases/24.10.11 dataproc: python_main_module: gs://genetics_etl_python_playground/initialisation/gentropy/szsz-fix-vep-vi-schema-ordering/cli.py cluster_metadata: @@ -16,7 +16,7 @@ nodes: step: biosample_index step.cell_ontology_input_path: gs://open-targets-pre-data-releases/24.06dev-test/input/biosamples/cl.json step.uberon_input_path: gs://open-targets-pre-data-releases/24.06dev-test/input/biosamples/uberon.json - step.biosample_index_path: gs://ot_orchestration/releases/27.09/biosample_index + step.biosample_index_path: gs://ot_orchestration/releases/24.10.11/biosample_index - id: study_validation kind: Task prerequisites: @@ -30,9 +30,9 @@ nodes: - gs://finngen_data/r11/study_index step.target_index_path: gs://genetics_etl_python_playground/releases/24.06/gene_index step.disease_index_path: gs://open-targets-pre-data-releases/24.06/output/etl/parquet/diseases - step.valid_study_index_path: &valid_study_index gs://ot_orchestration/releases/27.09/study_index - step.invalid_study_index_path: gs://ot_orchestration/releases/27.09/invalid_study_index - step.biosample_index_path: gs://ot_orchestration/releases/27.09/biosample_index + step.valid_study_index_path: &valid_study_index gs://ot_orchestration/releases/24.10.11/study_index + step.invalid_study_index_path: gs://ot_orchestration/releases/24.10.11/invalid_study_index + step.biosample_index_path: gs://ot_orchestration/releases/24.10.11/biosample_index step.invalid_qc_reasons: - UNRESOLVED_TARGET - UNRESOLVED_DISEASE @@ -53,10 +53,10 @@ nodes: - gs://gwas_catalog_data/credible_set_datasets/gwas_catalog_PICSed_summary_statistics - gs://gwas_catalog_data/credible_set_datasets/gwas_catalog_susie_summary_statistics - gs://eqtl_catalogue_data/credible_set_datasets/eqtl_catalogue_susie/ - - gs://ukb_ppp_eur_data/credible_set_datasets/ukb_ppp_eur_susie + - gs://ukb_ppp_eur_data/credible_set_datasets/susie - gs://finngen_data/r11/credible_set_datasets/finngen_r11_susie - step.valid_study_locus_path: &valid_credible_set gs://ot_orchestration/releases/27.09/credible_set - 
step.invalid_study_locus_path: gs://ot_orchestration/releases/27.09/invalid_credible_set + step.valid_study_locus_path: &valid_credible_set gs://ot_orchestration/releases/24.10.11/credible_set + step.invalid_study_locus_path: gs://ot_orchestration/releases/24.10.11/invalid_credible_set step.invalid_qc_reasons: - DUPLICATED_STUDYLOCUS_ID - AMBIGUOUS_STUDY @@ -97,7 +97,7 @@ nodes: SOURCE_PATH: gs://open-targets-pre-data-releases/24.09/input/pharmacogenomics-inputs/pharmacogenomics.json.gz SOURCE_FORMAT: json - SOURCE_NAME: gentropy_credible_sets - SOURCE_PATH: gs://ot_orchestration/releases/27.09/credible_set + SOURCE_PATH: gs://ot_orchestration/releases/24.10.11/credible_set SOURCE_FORMAT: parquet resource_specs: cpu_milli: 2000 @@ -112,7 +112,7 @@ nodes: step: variant_to_vcf step.source_path: $SOURCE_PATH step.source_format: $SOURCE_FORMAT - step.vcf_path: gs://ot_orchestration/releases/27.09/variants/raw_variants/$SOURCE_NAME + step.vcf_path: gs://ot_orchestration/releases/24.10.11/variants/raw_variants/$SOURCE_NAME step.session.write_mode: overwrite +step.session.extended_spark_conf: "{spark.jars:https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar}" @@ -120,8 +120,8 @@ nodes: kind: Task params: # the extension after saving from pyspark csv is going to be the .csv - input_vcf_glob: gs://ot_orchestration/releases/27.09/variants/raw_variants/**.csv - output_path: gs://ot_orchestration/releases/27.09/variants/merged_variants + input_vcf_glob: gs://ot_orchestration/releases/24.10.11/variants/raw_variants/**.csv + output_path: gs://ot_orchestration/releases/24.10.11/variants/merged_variants chunk_size: 2000 prerequisites: - variant_to_vcf @@ -141,8 +141,8 @@ nodes: machine_type: n1-standard-4 params: vep_cache_path: gs://genetics_etl_python_playground/vep/cache - vcf_input_path: gs://ot_orchestration/releases/27.09/variants/merged_variants - vep_output_path: gs://ot_orchestration/releases/27.09/variants/annotated_variants + vcf_input_path: gs://ot_orchestration/releases/24.10.11/variants/merged_variants + vep_output_path: gs://ot_orchestration/releases/24.10.11/variants/annotated_variants prerequisites: - list_nonannotated_vcfs prerequisites: @@ -151,9 +151,9 @@ nodes: command: gs://genetics_etl_python_playground/initialisation/0.0.0/cli.py params: step: variant_index - step.vep_output_json_path: gs://ot_orchestration/releases/27.09/variants/annotated_variants + step.vep_output_json_path: gs://ot_orchestration/releases/24.10.11/variants/annotated_variants step.gnomad_variant_annotations_path: gs://genetics_etl_python_playground/static_assets/gnomad_variants - step.variant_index_path: gs://ot_orchestration/releases/27.09/variant_index + step.variant_index_path: gs://ot_orchestration/releases/24.10.11/variant_index prerequisites: - variant_annotation - id: gene_index @@ -161,13 +161,13 @@ nodes: params: step: gene_index step.target_index: gs://genetics_etl_python_playground/static_assets/targets # OTP 23.12 data - step.gene_index_path: gs://ot_orchestration/releases/27.09/gene_index + step.gene_index_path: gs://ot_orchestration/releases/24.10.11/gene_index - id: variant_to_gene command: gs://genetics_etl_python_playground/initialisation/0.0.0/cli.py params: step: variant_to_gene - step.variant_index_path: gs://ot_orchestration/releases/27.09/variant_index - step.gene_index_path: gs://ot_orchestration/releases/27.09/gene_index + step.variant_index_path: gs://ot_orchestration/releases/24.10.11/variant_index + step.gene_index_path: 
gs://ot_orchestration/releases/24.10.11/gene_index step.vep_consequences_path: gs://genetics_etl_python_playground/static_assets/variant_consequence_to_score.tsv step.liftover_chain_file_path: gs://genetics_etl_python_playground/static_assets/grch37_to_grch38.over.chain step.interval_sources: @@ -175,7 +175,7 @@ nodes: - javierre: gs://genetics_etl_python_playground/static_assets/javierre_2016_preprocessed - jung: gs://genetics_etl_python_playground/static_assets/jung2019_pchic_tableS3.csv - thurman: gs://genetics_etl_python_playground/static_assets/thurman2012/genomewideCorrs_above0.7_promoterPlusMinus500kb_withGeneNames_32celltypeCategories.bed8.gz - step.v2g_path: gs://ot_orchestration/releases/27.09/variant_to_gene + step.v2g_path: gs://ot_orchestration/releases/24.10.11/variant_to_gene prerequisites: - variant_index - gene_index @@ -184,7 +184,7 @@ nodes: params: step: colocalisation step.credible_set_path: *valid_credible_set - step.coloc_path: gs://ot_orchestration/releases/27.09/colocalisation/ecaviar + step.coloc_path: gs://ot_orchestration/releases/24.10.11/colocalisation/ecaviar step.colocalisation_method: ECaviar prerequisites: - credible_set_validation @@ -193,7 +193,7 @@ nodes: params: step: colocalisation step.credible_set_path: *valid_credible_set - step.coloc_path: gs://ot_orchestration/releases/27.09/colocalisation/coloc + step.coloc_path: gs://ot_orchestration/releases/24.10.11/colocalisation/coloc step.colocalisation_method: Coloc prerequisites: - colocalisation_ecaviar # ensure two tasks are not draining each other resources @@ -204,13 +204,13 @@ nodes: step.run_mode: train step.wandb_run_name: "24.09" step.hf_hub_repo_id: opentargets/locus_to_gene - step.model_path: gs://ot_orchestration/releases/27.09/locus_to_gene_model/classifier.skops - step.predictions_path: gs://ot_orchestration/releases/27.09/locus_to_gene_predictions + step.model_path: gs://ot_orchestration/releases/24.10.11/locus_to_gene_model/classifier.skops + step.predictions_path: gs://ot_orchestration/releases/24.10.11/locus_to_gene_predictions step.credible_set_path: *valid_credible_set - step.colocalisation_path: gs://ot_orchestration/releases/27.09/colocalisation - step.variant_gene_path: gs://ot_orchestration/releases/27.09/variant_to_gene + step.colocalisation_path: gs://ot_orchestration/releases/24.10.11/colocalisation + step.variant_gene_path: gs://ot_orchestration/releases/24.10.11/variant_to_gene step.study_index_path: *valid_study_index - step.gold_standard_curation_path: gs://ot_orchestration/releases/27.09/locus_to_gene_gold_standard.json + step.gold_standard_curation_path: gs://ot_orchestration/releases/24.10.11/locus_to_gene_gold_standard.json step.gene_interactions_path: gs://genetics_etl_python_playground/static_assets/targets # OTP 233.12 data step.hyperparameters: n_estimators: 100 @@ -227,11 +227,11 @@ nodes: step: locus_to_gene step.run_mode: predict step.model_path: null - step.predictions_path: gs://ot_orchestration/releases/27.09/locus_to_gene_predictions - step.feature_matrix_path: gs://ot_orchestration/releases/27.09/locus_to_gene_feature_matrix + step.predictions_path: gs://ot_orchestration/releases/24.10.11/locus_to_gene_predictions + step.feature_matrix_path: gs://ot_orchestration/releases/24.10.11/locus_to_gene_feature_matrix step.credible_set_path: *valid_credible_set step.study_index_path: *valid_study_index - step.variant_gene_path: gs://ot_orchestration/releases/27.09/variant_to_gene - step.colocalisation_path: gs://ot_orchestration/releases/27.09/colocalisation + 
step.variant_gene_path: gs://ot_orchestration/releases/24.10.11/variant_to_gene + step.colocalisation_path: gs://ot_orchestration/releases/24.10.11/colocalisation prerequisites: - l2g_train diff --git a/src/ot_orchestration/dags/config/ukb_ppp_eur_finemapping.yaml b/src/ot_orchestration/dags/config/ukb_ppp_eur_finemapping.yaml new file mode 100644 index 0000000..5760a86 --- /dev/null +++ b/src/ot_orchestration/dags/config/ukb_ppp_eur_finemapping.yaml @@ -0,0 +1,37 @@ +nodes: + - id: generate_manifests + kind: Task + prerequisites: [] + params: + collected_loci_path: gs://ukb_ppp_eur_data/study_locus_lb_clumped + manifest_prefix: gs://ukb_ppp_eur_data/finemapping_manifests + output_path: gs://ukb_ppp_eur_data/credible_set_datasets/susie + max_records_per_chunk: 100_000 + + - id: finemapping_batch_job + kind: Task + prerequisites: + - generate_manifests + params: + study_index_path: gs://ukb_ppp_eur_data/study_index + google_batch: + entrypoint: /bin/sh + image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/gentropy-app/gentropy:dev + resource_specs: + cpu_milli: 4000 + memory_mib: 25000 + boot_disk_mib: 20_000 + task_specs: + max_retry_count: 5 + max_run_duration: "7200s" + policy_specs: + machine_type: n1-highmem-4 + + - id: move_finemapping_logs + kind: Task + prerequisites: + - finemapping_batch_job + params: + log_files_in: gs://ukb_ppp_eur_data/credible_set_datasets/susie + log_files_out: gs://ukb_ppp_eur_data/finemapping_logs/ + match_glob: "**.log" diff --git a/src/ot_orchestration/dags/config/ukb_ppp_eur_harmonisation.yaml b/src/ot_orchestration/dags/config/ukb_ppp_eur_harmonisation.yaml index b444da4..2006ded 100644 --- a/src/ot_orchestration/dags/config/ukb_ppp_eur_harmonisation.yaml +++ b/src/ot_orchestration/dags/config/ukb_ppp_eur_harmonisation.yaml @@ -1,18 +1,42 @@ dataproc: - python_main_module: gs://genetics_etl_python_playground/initialisation/gentropy/szsz-update-package-for-dataproc-run/cli.py + python_main_module: gs://genetics_etl_python_playground/initialisation/gentropy/dev/cli.py cluster_metadata: - PACKAGE: gs://genetics_etl_python_playground/initialisation/gentropy/szsz-update-package-for-dataproc-run/gentropy-0.0.0-py3-none-any.whl - cluster_init_script: gs://genetics_etl_python_playground/initialisation/0.0.0/install_dependencies_on_cluster.sh + PACKAGE: gs://genetics_etl_python_playground/initialisation/gentropy/dev/gentropy-0.0.0-py3-none-any.whl + cluster_init_script: gs://genetics_etl_python_playground/initialisation/gentropy/dev/install_dependencies_on_cluster.sh cluster_name: otg-ukb-ppp-eur + autoscaling_policy: otg-etl + nodes: - id: ukb_ppp_eur_sumstat_preprocess kind: Task prerequisites: [] params: + # NOTE: Check documentation to see how to generate raw input files from source + step: ukb_ppp_eur_sumstat_preprocess step.raw_study_index_path_from_tsv: gs://gentropy-tmp/batch/output/ukb_ppp_eur/study_index.tsv step.raw_summary_stats_path: gs://gentropy-tmp/batch/output/ukb_ppp_eur/summary_stats.parquet - step.variant_annotation_path: gs://genetics_etl_python_playground/output/python_etl/parquet/XX.XX/variant_annotation + # all other parameters step.tmp_variant_annotation_path: gs://gentropy-tmp/variant_annotation + step.variant_annotation_path: gs://gnomad_data_2/gnomad_variant_index step.study_index_output_path: gs://ukb_ppp_eur_data/study_index - step.summary_stats_output_path: gs://ukb_ppp_eur_data/summary_stats + step.summary_stats_output_path: gs://ukb_ppp_eur_data/harmonised_summary_statistics + step.session.write_mode: overwrite + + - id: 
locus_breaker_clumping + kind: Task + prerequisites: + - ukb_ppp_eur_sumstat_preprocess + params: + step: locus_breaker_clumping + step.summary_statistics_input_path: gs://ukb_ppp_eur_data/harmonised_summary_statistics + step.clumped_study_locus_output_path: gs://ukb_ppp_eur_data/study_locus_lb_clumped + step.lbc_baseline_pvalue: 1.0e-5 + step.lbc_distance_cutoff: 250_000 + step.lbc_pvalue_threshold: 1.7e-11 + step.lbc_flanking_distance: 100_000 + step.large_loci_size: 1_500_000 + step.wbc_clump_distance: 500_000 + step.wbc_pvalue_threshold: 1.7e-11 + step.collect_locus: True + step.remove_mhc: True step.session.write_mode: overwrite diff --git a/src/ot_orchestration/dags/genetics_etl.py b/src/ot_orchestration/dags/genetics_etl.py index 10c8210..e82edf8 100644 --- a/src/ot_orchestration/dags/genetics_etl.py +++ b/src/ot_orchestration/dags/genetics_etl.py @@ -15,7 +15,7 @@ from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator from airflow.utils.task_group import TaskGroup -from ot_orchestration.operators.vep import ( +from ot_orchestration.operators.batch.vep import ( ConvertVariantsToVcfOperator, VepAnnotateOperator, ) diff --git a/src/ot_orchestration/dags/ukb_ppp_eur_finemapping.py b/src/ot_orchestration/dags/ukb_ppp_eur_finemapping.py new file mode 100644 index 0000000..3f9bd44 --- /dev/null +++ b/src/ot_orchestration/dags/ukb_ppp_eur_finemapping.py @@ -0,0 +1,64 @@ +"""Airflow DAG that uses Google Cloud Batch to run the SuSie Finemapper step for UKB PPP.""" + +from pathlib import Path + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator + +from ot_orchestration.operators.batch.finemapping import ( + FinemappingBatchJobManifestOperator, + FinemappingBatchOperator, +) +from ot_orchestration.utils import ( + chain_dependencies, + find_node_in_config, + read_yaml_config, +) +from ot_orchestration.utils.common import shared_dag_args, shared_dag_kwargs +from ot_orchestration.utils.path import GCSPath + +config = read_yaml_config( + Path(__file__).parent / "config" / "ukb_ppp_eur_finemapping.yaml" +) + + +with DAG( + dag_id=Path(__file__).stem, + description="Open Targets Genetics — Susie Finemap UKB PPP (EUR)", + default_args=shared_dag_args, + **shared_dag_kwargs, +): + tasks = {} + + task_config = find_node_in_config(config["nodes"], "generate_manifests") + generate_manifests = FinemappingBatchJobManifestOperator( + task_id=task_config["id"], + **task_config["params"], + ) + + task_config = find_node_in_config(config["nodes"], "finemapping_batch_job") + finemapping_job = FinemappingBatchOperator.partial( + task_id=task_config["id"], + study_index_path=task_config["params"]["study_index_path"], + google_batch=task_config["google_batch"], + ).expand(manifest=generate_manifests.output) + + task_config = find_node_in_config(config["nodes"], "move_finemapping_logs") + source_file_path = GCSPath(task_config["params"]["log_files_in"]) + destination_file_path = GCSPath(task_config["params"]["log_files_out"]) + + move_logs_job = GCSToGCSOperator( + task_id=task_config["id"], + source_bucket=source_file_path.bucket, + source_object=source_file_path.path, + match_glob=task_config["params"]["match_glob"], + destination_bucket=destination_file_path.bucket, + destination_object=destination_file_path.path, + move_object=True, + ) + + tasks[generate_manifests.task_id] = generate_manifests + tasks[finemapping_job.task_id] = finemapping_job + tasks[move_logs_job.task_id] = move_logs_job + + 
chain_dependencies(nodes=config["nodes"], tasks_or_task_groups=tasks) diff --git a/src/ot_orchestration/dags/ukb_ppp_finemapping.py b/src/ot_orchestration/dags/ukb_ppp_finemapping.py deleted file mode 100644 index 36e1d9b..0000000 --- a/src/ot_orchestration/dags/ukb_ppp_finemapping.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Airflow DAG that uses Google Cloud Batch to run the SuSie Finemapper step for UKB PPP.""" - -from __future__ import annotations - -from pathlib import Path - -from airflow.decorators import task -from airflow.models.dag import DAG - -from ot_orchestration.templates.finemapping import ( - FinemappingBatchOperator, - generate_manifests_for_finemapping, -) -from ot_orchestration.utils import common - -COLLECTED_LOCI = ( - "gs://genetics-portal-dev-analysis/dc16/output/ukb_ppp/clean_loci.parquet" -) -MANIFEST_PREFIX = "gs://gentropy-tmp/ukb/manifest" -OUTPUT_BASE_PATH = "gs://gentropy-tmp/ukb/output" -STUDY_INDEX_PATH = "gs://ukb_ppp_eur_data/study_index" - - -@task -def generate_manifests(): - return generate_manifests_for_finemapping( - collected_loci=COLLECTED_LOCI, - manifest_prefix=MANIFEST_PREFIX, - output_path=OUTPUT_BASE_PATH, - max_records_per_chunk=100_000, - ) - - -with DAG( - dag_id=Path(__file__).stem, - description="Open Targets Genetics — finemap study loci with SuSie", - default_args=common.shared_dag_args, - **common.shared_dag_kwargs, -) as dag: - ( - FinemappingBatchOperator.partial( - task_id="finemapping_batch_job", study_index_path=STUDY_INDEX_PATH - ).expand(manifest=generate_manifests()) - ) diff --git a/src/ot_orchestration/operators/batch/__init__.py b/src/ot_orchestration/operators/batch/__init__.py new file mode 100644 index 0000000..f61bba1 --- /dev/null +++ b/src/ot_orchestration/operators/batch/__init__.py @@ -0,0 +1 @@ +"""Batch operators.""" diff --git a/src/ot_orchestration/operators/batch/finemapping.py b/src/ot_orchestration/operators/batch/finemapping.py new file mode 100644 index 0000000..d63b3f8 --- /dev/null +++ b/src/ot_orchestration/operators/batch/finemapping.py @@ -0,0 +1,213 @@ +"""Finemapping operators.""" + +import time +from functools import cached_property + +from airflow.models.baseoperator import BaseOperator +from airflow.providers.google.cloud.operators.cloud_batch import ( + CloudBatchSubmitJobOperator, +) +from google.cloud.batch import LifecyclePolicy + +from ot_orchestration.types import GoogleBatchSpecs +from ot_orchestration.utils.batch import ( + create_batch_job, + create_task_env, + create_task_spec, +) +from ot_orchestration.utils.common import GCP_PROJECT_GENETICS, GCP_REGION +from ot_orchestration.utils.path import GCSPath, IOManager, extract_partition_from_blob + + +class FinemappingBatchJobManifestOperator(BaseOperator): + """Generate a manifest for a fine-mapping job.""" + + def __init__( + self, + collected_loci_path: str, + manifest_prefix: str, + output_path: str, + max_records_per_chunk: int = 100_000, + **kwargs, + ): + self.log.info("Using collected loci from %s", collected_loci_path) + self.log.info("Saving manifests to %s", manifest_prefix) + self.log.info("The output of the finemapping will be in %s", output_path) + self.collected_loci_path = GCSPath(collected_loci_path) + self.manifest_prefix = manifest_prefix + self.output_path = output_path + self.max_records_per_chunk = max_records_per_chunk + super().__init__(**kwargs) + + def execute(self, context): + """Execute the operator.""" + return self.generate_manifests_for_finemapping() + + @cached_property + def io_manager(self) -> IOManager: + 
"""Property to get the IOManager to load and dump files.""" + return IOManager() + + def _extract_study_locus_ids_from_blobs(self, collected_loci_path) -> list[str]: + """Get list of loci from the input Google Storage path. + + NOTE: This step requires the dataset to be partitioned only by StudyLocusId!! + """ + self.log.info( + "Extracting studyLocusId from partition names in %s.", + self.collected_loci_path, + ) + client = self.collected_loci_path.client + bucket = client.get_bucket(self.collected_loci_path.bucket) + blobs = bucket.list_blobs(prefix=self.collected_loci_path.path) + all_study_locus_ids = [ + extract_partition_from_blob(blob.name) + for blob in blobs + if "studyLocusId" in blob.name + ] + self.log.info("Found %s studyLocusId(s)", len(all_study_locus_ids)) + return all_study_locus_ids + + def _generate_manifest_rows(self, study_locus_ids: list[str]) -> list[str]: + """This method generates a list containing all rows that will be used to generate the manifests.""" + self.log.info("Concatenating studyLocusId(s) to create manifest rows.") + manifest_rows = [ + f"{self.collected_loci_path}/{locus},{self.output_path}/{locus}" + for locus in study_locus_ids + ] + return manifest_rows + + def _partition_rows_by_range(self, manifest_rows: list[str]) -> list[list[str]]: + """This method partitions rows by pre-defined range.""" + manifest_chunks: list[list[str]] = [] + if self.max_records_per_chunk > len(manifest_rows): + self.log.warning( + "Consider down sampling the `max_records_per_chunk` parameter. Currently it outputs 1 partition." + ) + self.max_records_per_chunk = len(manifest_rows) + self.log.info( + "Partitioning %s manifest rows by %s studyLocusId chunks.", + len(manifest_rows), + self.max_records_per_chunk, + ) + for i in range(0, len(manifest_rows), self.max_records_per_chunk): + chunk = manifest_rows[i : i + self.max_records_per_chunk] + lines = ["study_locus_input,study_locus_output"] + chunk + manifest_chunks.append(lines) + self.log.info("Example output %s", lines[0:2]) + + return manifest_chunks + + def _prepare_batch_task_env( + self, manifest_chunks: list[list[str]] + ) -> list[tuple[int, str, int]]: + """Get the environment that will be used by batch tasks.""" + transfer_objects = [] + env_objects: list[tuple[int, str, int]] = [] + for i, lines in enumerate(manifest_chunks): + self.log.info("Amending %s lines for %s manifest", len(lines) - 1, i) + text = "\n".join(lines) + manifest_path = f"{self.manifest_prefix}/chunk_{i}" + self.log.info("Writing manifest to %s.", manifest_path) + transfer_objects.append((manifest_path, text)) + env_objects.append((i, manifest_path, len(lines) - 1)) + + self.log.info("Writing %s manifests", len(transfer_objects)) + self.io_manager.dump_many( + paths=[t[0] for t in transfer_objects], + objects=[t[1] for t in transfer_objects], + ) + return env_objects + + def generate_manifests_for_finemapping(self) -> list[tuple[int, str, int]]: + """Starting from collected_loci, generate manifests for finemapping, splitting in chunks of at most 100,000 records. + + This step saves the manifests to GCS under the manifest_prefix path with suffix `chunk_{i}`. Each chunk is a csv + file with two columns: study_locus_input and study_locus_output. + + Return: + list[(int, str, int)]: List of tuples, where the first value is index of the manifest, the second value is a path to manifest, and the third is the number of records in that manifest. 
+ """ + all_study_locus_ids = self._extract_study_locus_ids_from_blobs( + self.collected_loci_path + ) + manifest_rows = self._generate_manifest_rows(all_study_locus_ids) + manifest_chunks = self._partition_rows_by_range(manifest_rows) + environments = self._prepare_batch_task_env(manifest_chunks) + return environments + + +class FinemappingBatchOperator(CloudBatchSubmitJobOperator): + def __init__( + self, + manifest: tuple[int, str, int], + study_index_path: str, + google_batch: GoogleBatchSpecs, + **kwargs, + ): + self.study_index_path = study_index_path + self.idx, self.study_locus_manifest_path, self.num_of_tasks = manifest + + super().__init__( + project_id=GCP_PROJECT_GENETICS, + region=GCP_REGION, + job_name=f"finemapping-job-{self.idx}-{time.strftime('%Y%m%d-%H%M%S')}", + job=create_batch_job( + task=create_task_spec( + image=google_batch["image"], + commands=self.susie_finemapping_command, + task_specs=google_batch["task_specs"], + resource_specs=google_batch["resource_specs"], + entrypoint=google_batch["entrypoint"], + lifecycle_policies=[ + LifecyclePolicy( + action=LifecyclePolicy.Action.FAIL_TASK, + action_condition=LifecyclePolicy.ActionCondition( + exit_codes=[50005] # Execution time exceeded. + ), + ) + ], + ), + task_env=create_task_env( + var_list=[ + {"LOCUS_INDEX": str(idx)} for idx in range(0, manifest[2]) + ] + ), + policy_specs=google_batch["policy_specs"], + ), + deferrable=False, + **kwargs, + ) + + @property + def susie_finemapping_command(self) -> list[str]: + """Get the command for running the finemapping batch job.""" + return [ + "-c", + ( + "poetry run gentropy " + "step=susie_finemapping " + f"step.study_index_path={self.study_index_path} " + f"step.study_locus_manifest_path={self.study_locus_manifest_path} " + "step.study_locus_index=$LOCUS_INDEX " + "step.max_causal_snps=10 " + "step.lead_pval_threshold=1e-5 " + "step.purity_mean_r2_threshold=0 " + "step.purity_min_r2_threshold=0.25 " + "step.cs_lbf_thr=2 step.sum_pips=0.99 " + "step.susie_est_tausq=False " + "step.run_carma=False " + "step.run_sumstat_imputation=False " + "step.carma_time_limit=600 " + "step.imputed_r2_threshold=0.9 " + "step.ld_score_threshold=5 " + "step.carma_tau=0.15 " + "step.ld_min_r2=0.8 " + "+step.session.extended_spark_conf={spark.jars:https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar} " + "+step.session.extended_spark_conf={spark.dynamicAllocation.enabled:false} " + "+step.session.extended_spark_conf={spark.driver.memory:30g} " + "+step.session.extended_spark_conf={spark.kryoserializer.buffer.max:500m} " + "+step.session.extended_spark_conf={spark.driver.maxResultSize:5g} " + "step.session.write_mode=overwrite" + ), + ] diff --git a/src/ot_orchestration/operators/vep.py b/src/ot_orchestration/operators/batch/vep.py similarity index 100% rename from src/ot_orchestration/operators/vep.py rename to src/ot_orchestration/operators/batch/vep.py diff --git a/src/ot_orchestration/utils/batch.py b/src/ot_orchestration/utils/batch.py index ddf03ce..1b4cb24 100644 --- a/src/ot_orchestration/utils/batch.py +++ b/src/ot_orchestration/utils/batch.py @@ -8,6 +8,7 @@ ComputeResource, Environment, Job, + LifecyclePolicy, LogsPolicy, Runnable, TaskGroup, @@ -51,6 +52,7 @@ def create_task_spec( commands: list[str], resource_specs: BatchResourceSpecs, task_specs: BatchTaskSpecs, + lifecycle_policies: list[LifecyclePolicy] | None = None, **kwargs: Any, ) -> TaskSpec: """Create a task for a Batch job. 
@@ -60,22 +62,26 @@ def create_task_spec( commands (list[str]): The commands to run in the container. resource_specs (BatchResourceSpecs): The specification of the resources for the task. task_specs (BatchTaskSpecs): The specification of the task. + lifecycle_policies (list[LifecyclePolicy] | None) : Lifecycle policies. **kwargs (Any): Any additional parameter to pass to the container runnable Returns: TaskSpec: The task specification. """ time_duration = time_to_seconds(task_specs["max_run_duration"]) - task = TaskSpec( - runnables=[create_container_runnable(image, commands=commands, **kwargs)], - compute_resource=ComputeResource( + parameters = { + "runnables": [create_container_runnable(image, commands=commands, **kwargs)], + "compute_resource": ComputeResource( cpu_milli=resource_specs["cpu_milli"], memory_mib=resource_specs["memory_mib"], boot_disk_mib=resource_specs["boot_disk_mib"], ), - max_retry_count=task_specs["max_retry_count"], - max_run_duration=f"{time_duration}s", # type: ignore - ) + "max_retry_count": task_specs["max_retry_count"], + "max_run_duration": f"{time_duration}s", # type: ignore + } + if lifecycle_policies: + parameters["lifecycle_policies"] = lifecycle_policies + task = TaskSpec(**parameters) return task @@ -139,14 +145,16 @@ def create_batch_job( def create_task_env(var_list: list[dict[str, Any]]): """This function creates list of batch_v1.Environment objects from provided list of dictionaries.""" - return [Environment(variables=variables) for variables in var_list] + print(f"{var_list=}") + environments = [Environment(variables=variables) for variables in var_list] + return environments def create_task_commands( commands: list[str] | None, params: dict[str, dict[str, Any] | None] ) -> list[str]: """This function prepares list of commands for google batch job from the step configuration.""" - args = convert_params_to_hydra_positional_arg(params=params) + args = convert_params_to_hydra_positional_arg(params=params, dataproc=False) task_commands = [] if commands: task_commands.extend(commands) diff --git a/src/ot_orchestration/utils/dataproc.py b/src/ot_orchestration/utils/dataproc.py index 4930de2..a6532c6 100644 --- a/src/ot_orchestration/utils/dataproc.py +++ b/src/ot_orchestration/utils/dataproc.py @@ -67,7 +67,7 @@ def create_cluster( image_version=GCP_DATAPROC_IMAGE, enable_component_gateway=True, metadata=cluster_metadata, - # idle_delete_ttl=30 * 60, # In seconds. + idle_delete_ttl=30 * 60, # In seconds. init_actions_uris=[cluster_init_script] if cluster_init_script else None, autoscaling_policy=f"projects/{GCP_PROJECT_GENETICS}/regions/{GCP_REGION}/autoscalingPolicies/{autoscaling_policy}", ).make() diff --git a/src/ot_orchestration/utils/path.py b/src/ot_orchestration/utils/path.py index f0e5825..0744cd8 100644 --- a/src/ot_orchestration/utils/path.py +++ b/src/ot_orchestration/utils/path.py @@ -17,6 +17,7 @@ CHUNK_SIZE = 1024 * 256 MAX_N_THREADS = 32 URI_PATTERN = r"^^((?P.*)://)?(?P[(\w)-]+)/(?P([(\w)-/])+)" +PARTITION_REGEX = r"\w*=\w*" class PathSegments(TypedDict): @@ -458,3 +459,24 @@ class ThreadSafetyError(Exception): """Exception raised for errors in thread safety.""" pass + + +def extract_partition_from_blob(blob: storage.Blob | str) -> str: + """Extract partition prefix from a Google Cloud Storage Blob. + + Args: + blob (storage.Blob): Google Cloud Storage Blob. + + Returns: + str: Partition prefix. 
+ """ + if isinstance(blob, str): + name = blob + if isinstance(blob, storage.Blob): + name: str = blob.name # type: ignore + if name.endswith("/"): + name = name[:-1] + _match = re.search(PARTITION_REGEX, name) + if not _match: + raise ValueError("No partition found in %s", name) + return _match.group(0) diff --git a/src/ot_orchestration/utils/utils.py b/src/ot_orchestration/utils/utils.py index 57aa1f6..fc1e46c 100644 --- a/src/ot_orchestration/utils/utils.py +++ b/src/ot_orchestration/utils/utils.py @@ -138,7 +138,7 @@ def chain_dependencies(nodes: list[ConfigNode], tasks_or_task_groups: dict[str, def convert_params_to_hydra_positional_arg( - params: dict[str, Any] | None, + params: dict[str, Any] | None, dataproc: bool = False ) -> list[str]: """Convert configuration parameters to form that can be passed to hydra step positional arguments. @@ -148,6 +148,7 @@ def convert_params_to_hydra_positional_arg( Args: params (dict[str, Any]] | None): Parameters for the step to convert. + dataproc (bool): If true, adds the yarn as a session parameter. Raises: ValueError: When keys passed to the function params dict does not contain the `step.` prefix. @@ -160,8 +161,13 @@ def convert_params_to_hydra_positional_arg( incorrect_param_keys = [key for key in params if "step" not in key] if incorrect_param_keys: raise ValueError(f"Passed incorrect param keys {incorrect_param_keys}") - - return [f"{k}={v}" for k, v in params.items()] + positional_args = [f"{k}={v}" for k, v in params.items()] + if not dataproc: + return positional_args + yarn_session_config = "step.session.spark_uri=yarn" + if yarn_session_config not in positional_args: + positional_args.append(yarn_session_config) + return positional_args def find_node_in_config(config: list[ConfigNode], node_id: str) -> ConfigNode: diff --git a/tests/test_io_manager.py b/tests/test_io_manager.py index 11b6f67..d40606f 100644 --- a/tests/test_io_manager.py +++ b/tests/test_io_manager.py @@ -11,6 +11,7 @@ GCSPath, IOManager, NativePath, + extract_partition_from_blob, ) @@ -188,3 +189,28 @@ def test_bucket_property(self, gcs_path: str, bucket: str) -> None: """Test GCSPath object bucket property.""" gcs_path_obj = GCSPath(gcs_path) assert gcs_path_obj.bucket == bucket + + +@pytest.mark.parametrize( + ["input_blob", "partition"], + [ + pytest.param( + "gs://bucket/prefix/partition=123aa/file.parquet", + "partition=123aa", + id="single partition", + ), + pytest.param( + "gs://bucket/prefix/partition=123aa/otherPartition=123bbb/file.parquet", + "partition=123aa", + id="only first partition is checked", + ), + pytest.param( + "gs://bucket/prefix/partition=123aa/otherPartition=123bbb/file.parquet", + "partition=123aa", + id="only first partition is checked", + ), + ], +) +def test_extract_partition_from_blob(input_blob: str, partition: str) -> None: + """Test extracting partition from a blob.""" + assert extract_partition_from_blob(input_blob) == partition diff --git a/tests/test_utils.py b/tests/test_utils.py index 98e44f4..f95f695 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -24,26 +24,47 @@ def test_time_to_seconds(input: str, output: int) -> None: @pytest.mark.parametrize( - ["input", "output"], + ["input", "output", "is_dataproc_job"], [ pytest.param( {"step": "some_step", "step.b": 2, "+step.c": 3}, ["step=some_step", "step.b=2", "+step.c=3"], + False, id="step configuration", ), pytest.param( {"step": "some_step", "step.b": {"c": 2, "d": 3}}, ["step=some_step", "step.b.c=2", "step.b.d=3"], + False, id="nested dict", 
marks=pytest.mark.xfail( reason="Structured configuration not supported yet." ), ), + pytest.param( + {"step": "some_step", "step.b": 2, "+step.c": 3}, + ["step=some_step", "step.b=2", "+step.c=3", "step.session.spark_uri=yarn"], + True, + id="Running with dataproc=True adds yarn as a parameter", + ), + pytest.param( + { + "step": "some_step", + "step.b": 2, + "+step.c": 3, + "step.session.spark_uri": "yarn", + }, + ["step=some_step", "step.b=2", "+step.c=3", "step.session.spark_uri=yarn"], + True, + id="Running with dataproc=True and yarn present does not duplicate parameter", + ), ], ) -def test_convert_params_to_hydra_positional_arg(input: dict, output: list[str]) -> None: +def test_convert_params_to_hydra_positional_arg( + input: dict, output: list[str], is_dataproc_job: bool +) -> None: """Test conversion of dictionary to hydra positional arguments.""" - assert convert_params_to_hydra_positional_arg(input) == output + assert convert_params_to_hydra_positional_arg(input, is_dataproc_job) == output def test_find_node_in_config() -> None:
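For reference, a minimal usage sketch of the new `dataproc` flag exercised by the tests above (the import path `ot_orchestration.utils.utils` is assumed from the `src/ot_orchestration/utils/utils.py` change in this diff; parameter values are illustrative only):

```python
# Illustrative sketch only - mirrors the behaviour covered by the tests above.
from ot_orchestration.utils.utils import convert_params_to_hydra_positional_arg

params = {"step": "locus_breaker_clumping", "step.remove_mhc": True}

# Google Batch jobs (dataproc=False, the default): parameters are passed through unchanged.
assert convert_params_to_hydra_positional_arg(params) == [
    "step=locus_breaker_clumping",
    "step.remove_mhc=True",
]

# Dataproc jobs (dataproc=True): the yarn spark_uri session parameter is appended once,
# unless it is already present in the step configuration.
assert convert_params_to_hydra_positional_arg(params, dataproc=True) == [
    "step=locus_breaker_clumping",
    "step.remove_mhc=True",
    "step.session.spark_uri=yarn",
]
```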