From b36d77d59d556db8bb60930c631e185f99b03fdf Mon Sep 17 00:00:00 2001
From: Daniel Goldstein
Date: Mon, 27 Feb 2023 17:16:43 -0500
Subject: [PATCH] [qob] Support loading RGs from FASTA files

---
 hail/Makefile                                 |  2 +-
 hail/python/hail/backend/service_backend.py   | 77 ++++++++++++++-----
 hail/python/hail/genetics/reference_genome.py |  6 +-
 .../hail/genetics/test_reference_genome.py    |  1 -
 hail/python/test/hail/ggplot/test_ggplot.py   |  1 -
 .../is/hail/backend/local/LocalBackend.scala  |  1 -
 .../hail/backend/service/ServiceBackend.scala | 70 +++++++++++++++--
 .../is/hail/backend/spark/SparkBackend.scala  |  1 -
 .../is/hail/io/reference/FASTAReader.scala    | 14 +++-
 .../is/hail/variant/ReferenceGenome.scala     | 17 ++--
 10 files changed, 142 insertions(+), 48 deletions(-)

diff --git a/hail/Makefile b/hail/Makefile
index 02d21416c8fd..87e02970f90a 100644
--- a/hail/Makefile
+++ b/hail/Makefile
@@ -288,7 +288,7 @@ check-pip-lockfile:
 
 python/hailtop/hailctl/deploy.yaml: env/cloud_base env/wheel_cloud_path
 
-python/hailtop/hailctl/deploy.yaml: $(resources) check-pip-lockfile
+python/hailtop/hailctl/deploy.yaml: $(resources)
	rm -f $@
	echo "dataproc:" >> $@
	for FILE in $(notdir $(resources)); do \
diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py
index f9c5bf505b17..7bf4583bb475 100644
--- a/hail/python/hail/backend/service_backend.py
+++ b/hail/python/hail/backend/service_backend.py
@@ -1,6 +1,6 @@
 from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple
 import abc
-import collections
+import math
 import struct
 from hail.expr.expressions.base_expression import Expression
 import orjson
@@ -65,6 +65,12 @@ async def write_str(strm: afs.WritableStream, s: str):
     await write_bytes(strm, s.encode('utf-8'))
 
 
+async def write_str_array(strm: afs.WritableStream, los: List[str]):
+    await write_int(strm, len(los))
+    for s in los:
+        await write_str(strm, s)
+
+
 class EndOfStream(TransientError):
     pass
 
@@ -176,6 +182,7 @@ class ServiceBackend(Backend):
     PARSE_VCF_METADATA = 7
     INDEX_BGEN = 8
     IMPORT_FAM = 9
+    FROM_FASTA_FILE = 10
 
     @staticmethod
     async def create(*,
@@ -291,8 +298,6 @@ def __init__(self,
         self.worker_memory = worker_memory
         self.name_prefix = name_prefix
         self.regions = regions
-        # Source genome -> [Destination Genome -> Chain file]
-        self._liftovers: Dict[str, Dict[str, str]] = collections.defaultdict(dict)
 
     def debug_info(self) -> Dict[str, Any]:
         return {
@@ -341,6 +346,9 @@ async def _rpc(self,
         timings = Timings()
         token = secret_alnum_string()
         with TemporaryDirectory(ensure_exists=False) as iodir:
+            readonly_fuse_buckets = set()
+            storage_requirement_bytes = 0
+
             with timings.step("write input"):
                 async with await self._async_fs.create(iodir + '/in') as infile:
                     nonnull_flag_count = sum(v is not None for v in self.flags.values())
@@ -353,7 +361,7 @@ async def _rpc(self,
                     await write_int(infile, len(custom_references))
                     for reference_config in custom_references:
                         await write_str(infile, orjson.dumps(reference_config._config).decode('utf-8'))
-                    non_empty_liftovers = {name: liftovers for name, liftovers in self._liftovers.items() if len(liftovers) > 0}
+                    non_empty_liftovers = {rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}
                     await write_int(infile, len(non_empty_liftovers))
                     for source_genome_name, liftovers in non_empty_liftovers.items():
                         await write_str(infile, source_genome_name)
@@ -361,6 +369,15 @@ async def _rpc(self,
                         for dest_reference_genome, chain_file in liftovers.items():
                             await write_str(infile, dest_reference_genome)
                             await write_str(infile, chain_file)
+                    added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None}
+                    await write_int(infile, len(added_sequences))
+                    for rg_name, (fasta_file, index_file) in added_sequences.items():
+                        await write_str(infile, rg_name)
+                        for blob in (fasta_file, index_file):
+                            bucket, path = self._get_bucket_and_path(blob)
+                            readonly_fuse_buckets.add(bucket)
+                            storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size()
+                            await write_str(infile, f'/cloudfuse/{bucket}/{path}')
                     await write_str(infile, str(self.worker_cores))
                     await write_str(infile, str(self.worker_memory))
                     await write_int(infile, len(self.regions))
@@ -382,6 +399,9 @@ async def _rpc(self,
                     resources['cpu'] = str(self.driver_cores)
                 if self.driver_memory is not None:
                     resources['memory'] = str(self.driver_memory)
+                if storage_requirement_bytes != 0:
+                    storage_gib = math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)
+                    resources['storage'] = f'{storage_gib}Gi'
 
                 j = bb.create_jvm_job(
                     jar_spec=self.jar_spec.to_dict(),
@@ -395,6 +415,7 @@ async def _rpc(self,
                     resources=resources,
                     attributes={'name': name + '_driver'},
                     regions=self.regions,
+                    cloudfuse=[(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets]
                 )
                 self._batch = await bb.submit(disable_progress_bar=True)
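The two pieces of bookkeeping introduced in `_rpc` above feed straight into the job request: each FASTA/index blob's bucket gets fuse-mounted read-only under `/cloudfuse`, and the combined blob size is rounded up to whole GiB for the driver's storage resource. A minimal sketch of both mappings (bucket and path values are hypothetical; the real URL parsing is done by `self._async_fs.parse_url`):

```python
import math

def cloudfuse_path(bucket: str, path: str) -> str:
    # a blob at gs://<bucket>/<path> is visible to the driver job at
    # /cloudfuse/<bucket>/<path> once its bucket is mounted read-only
    return f'/cloudfuse/{bucket}/{path}'

def storage_request(total_bytes: int) -> str:
    # round the combined FASTA + index size up to whole GiB, as _rpc does
    return f'{math.ceil(total_bytes / 1024 / 1024 / 1024)}Gi'

assert cloudfuse_path('my-bucket', 'refs/hg38.fasta') == '/cloudfuse/my-bucket/refs/hg38.fasta'
assert storage_request(3_400_000_000) == '4Gi'  # ~3.17 GiB rounds up to 4
```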
@@ -527,7 +548,23 @@ async def inputs(infile, _):
         return tblockmatrix._from_json(orjson.loads(resp))
 
     def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
-        raise NotImplementedError("ServiceBackend does not support 'from_fasta_file'")
+        return async_to_blocking(self._from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par))
+
+    async def _from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par, *, progress: Optional[BatchProgressBar] = None):
+        async def inputs(infile, _):
+            await write_int(infile, ServiceBackend.FROM_FASTA_FILE)
+            await write_str(infile, tmp_dir())
+            await write_str(infile, self.billing_project)
+            await write_str(infile, self.remote_tmpdir)
+            await write_str(infile, name)
+            await write_str(infile, fasta_file)
+            await write_str(infile, index_file)
+            await write_str_array(infile, x_contigs)
+            await write_str_array(infile, y_contigs)
+            await write_str_array(infile, mt_contigs)
+            await write_str_array(infile, par)
+        _, resp, _ = await self._rpc('from_fasta_file(...)', inputs, progress=progress)
+        return orjson.loads(resp)
 
     def load_references_from_dataset(self, path):
         return async_to_blocking(self._async_load_references_from_dataset(path))
@@ -542,22 +579,26 @@ async def inputs(infile, _):
         _, resp, _ = await self._rpc('load_references_from_dataset(...)', inputs, progress=progress)
         return orjson.loads(resp)
 
-    def add_sequence(self, name, fasta_file, index_file):
-        raise NotImplementedError("ServiceBackend does not support 'add_sequence'")
+    # Sequence and liftover information is stored on the ReferenceGenome
+    # itself; there is no persistent backend state to keep in sync, so both
+    # are shipped to the driver on every RPC instead.
+    def add_sequence(self, name, fasta_file, index_file):  # pylint: disable=unused-argument
+        # FIXME Not only should this be in the cloud, it should be in the *right* cloud
+        for blob in (fasta_file, index_file):
+            self.validate_file_scheme(blob)
+
+    def remove_sequence(self, name):  # pylint: disable=unused-argument
+        pass
 
-    def remove_sequence(self, name):
-        raise NotImplementedError("ServiceBackend does not support 'remove_sequence'")
+    def _get_bucket_and_path(self, blob_uri):
+        url = self._async_fs.parse_url(blob_uri)
+        return '/'.join(url.bucket_parts), url.path
 
-    def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):
-        if name == dest_reference_genome:
-            raise ValueError(f'Destination reference genome cannot have the same name as this reference {name}.')
-        if dest_reference_genome in self._liftovers[name]:
-            raise ValueError(f'Chain file already exists for destination reference {dest_reference_genome}.')
-        self._liftovers[name][dest_reference_genome] = chain_file
+    def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):  # pylint: disable=unused-argument
+        pass
 
-    def remove_liftover(self, name, dest_reference_genome):
-        assert dest_reference_genome in self._liftovers[name]
-        del self._liftovers[name][dest_reference_genome]
+    def remove_liftover(self, name, dest_reference_genome):  # pylint: disable=unused-argument
+        pass
 
     def parse_vcf_metadata(self, path):
         return async_to_blocking(self._async_parse_vcf_metadata(path))
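For context on the wire format `_from_fasta_file` relies on: every value is length-prefixed, and the Scala driver reads fields back in exactly the order they were written. A self-contained sketch of the framing, assuming the 4-byte little-endian integer encoding `write_int` uses elsewhere in this file:

```python
import struct
from typing import List

def encode_int(v: int) -> bytes:
    # assumed framing: 4-byte little-endian, the counterpart of the driver's readInt()
    return struct.pack('<i', v)

def encode_str(s: str) -> bytes:
    # length prefix followed by UTF-8 bytes, the counterpart of readString()
    b = s.encode('utf-8')
    return encode_int(len(b)) + b

def encode_str_array(strs: List[str]) -> bytes:
    # element count, then each length-prefixed string; mirrors write_str_array
    return encode_int(len(strs)) + b''.join(encode_str(s) for s in strs)

# a FROM_FASTA_FILE payload starts with the command id, then the string fields
payload = encode_int(10) + encode_str('my_reference') + encode_str_array(['chrX'])
```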
diff --git a/hail/python/hail/genetics/reference_genome.py b/hail/python/hail/genetics/reference_genome.py
index dc522e9af3c9..b12ce6a9e485 100644
--- a/hail/python/hail/genetics/reference_genome.py
+++ b/hail/python/hail/genetics/reference_genome.py
@@ -419,8 +419,8 @@ def from_fasta_file(cls, name, fasta_file, index_file,
         par_strings = ["{}:{}-{}".format(contig, start, end) for (contig, start, end) in par]
         config = Env.backend().from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings)
-        rg = ReferenceGenome._from_config(config, _builtin=True)
-        rg._sequence_files = (fasta_file, index_file)
+        rg = ReferenceGenome._from_config(config)
+        rg.add_sequence(fasta_file, index_file)
         return rg
 
     @typecheck_method(dest_reference_genome=reference_genome_type)
@@ -497,6 +497,8 @@ def add_liftover(self, chain_file, dest_reference_genome):
         Env.backend().add_liftover(self.name, chain_file, dest_reference_genome.name)
         if dest_reference_genome.name in self._liftovers:
             raise KeyError(f"Liftover already exists from {self.name} to {dest_reference_genome.name}.")
+        if dest_reference_genome.name == self.name:
+            raise ValueError(f'Destination reference genome cannot have the same name as this reference {self.name}.')
         self._liftovers[dest_reference_genome.name] = chain_file
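With both sides wired up, the user-facing call works the same on Query-on-Batch as on Spark. A hypothetical example (paths, contig names, and PAR coordinates are made up; on the service backend the files must live in cloud storage so their buckets can be fuse-mounted):

```python
import hail as hl

rg = hl.ReferenceGenome.from_fasta_file(
    'my_reference',
    'gs://my-bucket/refs/my_reference.fasta.gz',
    'gs://my-bucket/refs/my_reference.fasta.gz.fai',
    x_contigs=['chrX'],
    y_contigs=['chrY'],
    mt_contigs=['chrM'],
    par=[('chrX', 10001, 2781479)])  # serialized as 'chrX:10001-2781479'
assert rg.has_sequence()
```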
diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py
index 410b08826bc8..586d494fb76b 100644
--- a/hail/python/test/hail/genetics/test_reference_genome.py
+++ b/hail/python/test/hail/genetics/test_reference_genome.py
@@ -35,7 +35,6 @@ def test_reference_genome():
     with hl.TemporaryFilename() as filename:
         gr2.write(filename)
 
-@fails_service_backend()
 def test_reference_genome_sequence():
     gr3 = ReferenceGenome.read(resource("fake_ref_genome.json"))
     assert gr3.name == "my_reference_genome"
diff --git a/hail/python/test/hail/ggplot/test_ggplot.py b/hail/python/test/hail/ggplot/test_ggplot.py
index 84d99295a6d9..b2349c7455f2 100644
--- a/hail/python/test/hail/ggplot/test_ggplot.py
+++ b/hail/python/test/hail/ggplot/test_ggplot.py
@@ -2,7 +2,6 @@
 from hail.ggplot import *
 import numpy as np
 import math
-from ..helpers import fails_service_backend
 
 
 def test_geom_point_line_text_col_area():
diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
index f22cb154226d..650f02f89724 100644
--- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -274,7 +274,6 @@ class LocalBackend(
     withExecuteContext(timer) { ctx =>
       val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
         xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
-      addReference(rg)
       rg.toJSONString
     }
   }
diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
index d668ad573363..228b23a50e37 100644
--- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -395,6 +395,20 @@ class ServiceBackend(
   ): TableStage = {
     LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
   }
+
+  def fromFASTAFile(
+    ctx: ExecuteContext,
+    name: String,
+    fastaFile: String,
+    indexFile: String,
+    xContigs: Array[String],
+    yContigs: Array[String],
+    mtContigs: Array[String],
+    parInput: Array[String]
+  ): String = {
+    val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput)
+    rg.toJSONString
+  }
 }
 
 class EndOfInputException extends RuntimeException
@@ -464,9 +478,12 @@ class ServiceBackendSocketAPI2(
   private[this] val PARSE_VCF_METADATA = 7
   private[this] val INDEX_BGEN = 8
   private[this] val IMPORT_FAM = 9
+  private[this] val FROM_FASTA_FILE = 10
 
   private[this] val dummy = new Array[Byte](8)
 
+  private[this] val log = Logger.getLogger(getClass.getName())
+
   def read(bytes: Array[Byte], off: Int, n: Int): Unit = {
     assert(off + n <= bytes.length)
     var read = 0
@@ -504,6 +521,17 @@ class ServiceBackendSocketAPI2(
 
   def readString(): String = new String(readBytes(), StandardCharsets.UTF_8)
 
+  def readStringArray(): Array[String] = {
+    val n = readInt()
+    val arr = new Array[String](n)
+    var i = 0
+    while (i < n) {
+      arr(i) = readString()
+      i += 1
+    }
+    arr
+  }
+
   def writeBool(b: Boolean): Unit = {
     out.write(if (b) 1 else 0)
   }
@@ -556,6 +584,16 @@ class ServiceBackendSocketAPI2(
       }
       i += 1
     }
+    val nAddedSequences = readInt()
+    val addedSequences = mutable.Map[String, (String, String)]()
+    i = 0
+    while (i < nAddedSequences) {
+      val rgName = readString()
+      val fastaFile = readString()
+      val indexFile = readString()
+      addedSequences(rgName) = (fastaFile, indexFile)
+      i += 1
+    }
     val workerCores = readString()
     val workerMemory = readString()
 
@@ -595,6 +633,9 @@ class ServiceBackendSocketAPI2(
           ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome)
         }
       }
+      addedSequences.foreach { case (rg, (fastaFile, indexFile)) =>
+        ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile)
+      }
       ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, regions)
       method(ctx)
     }
@@ -660,13 +701,7 @@ class ServiceBackendSocketAPI2(
             backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8)
           )
         case INDEX_BGEN =>
-          val nFiles = readInt()
-          val files = new Array[String](nFiles)
-          var i = 0
-          while (i < nFiles) {
-            files(i) = readString()
-            i += 1
-          }
+          val files = readStringArray()
           val nIndexFiles = readInt()
           val indexFileMap = mutable.Map[String, String]()
           i = 0
@@ -702,6 +737,27 @@ class ServiceBackendSocketAPI2(
               skipInvalidLoci
             ).getBytes(StandardCharsets.UTF_8)
           )
+        case FROM_FASTA_FILE =>
+          val name = readString()
+          val fastaFile = readString()
+          val indexFile = readString()
+          val xContigs = readStringArray()
+          val yContigs = readStringArray()
+          val mtContigs = readStringArray()
+          val parInput = readStringArray()
+          withExecuteContext(
+            "ServiceBackend.fromFASTAFile",
+            backend.fromFASTAFile(
+              _,
+              name,
+              fastaFile,
+              indexFile,
+              xContigs,
+              yContigs,
+              mtContigs,
+              parInput
+            ).getBytes(StandardCharsets.UTF_8)
+          )
       }
       writeBool(true)
       writeBytes(result)
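The dispatch above only works because the driver consumes fields in exactly the order the Python client wrote them: `readStringArray` is the inverse of `write_str_array`, and the FROM_FASTA_FILE case reads name, FASTA path, and index path, then the four string arrays. A Python sketch of the same decoding, shown only for symmetry with the encoder (the Scala code is what actually runs on the driver):

```python
import struct
from typing import BinaryIO, Dict, List, Tuple

def read_int(f: BinaryIO) -> int:
    return struct.unpack('<i', f.read(4))[0]  # assumed 4-byte little-endian

def read_str(f: BinaryIO) -> str:
    return f.read(read_int(f)).decode('utf-8')

def read_str_array(f: BinaryIO) -> List[str]:
    return [read_str(f) for _ in range(read_int(f))]

def read_added_sequences(f: BinaryIO) -> Dict[str, Tuple[str, str]]:
    # count, then (rg name, fasta path, index path) triples, matching the
    # while-loop that fills addedSequences on the Scala side
    seqs: Dict[str, Tuple[str, str]] = {}
    for _ in range(read_int(f)):
        name = read_str(f)
        fasta = read_str(f)
        index = read_str(f)
        seqs[name] = (fasta, index)
    return seqs
```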
diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
index e6655c420771..1301d8588655 100644
--- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
+++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -646,7 +646,6 @@ class SparkBackend(
     withExecuteContext(timer) { ctx =>
       val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
         xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
-      addReference(rg)
       rg.toJSONString
     }
   }
diff --git a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala
index 7f4ac39d5ecf..52e7b890a263 100644
--- a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala
+++ b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala
@@ -37,12 +37,18 @@ object FASTAReader {
   }
 
   def setup(tmpdir: String, fs: FS, fastaFile: String, indexFile: String): String = {
-    val localFastaFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta")
-    log.info(s"copying FASTA file at $fastaFile to $localFastaFile")
-    fs.copyRecode(fastaFile, localFastaFile)
+    // Paths that are already local (e.g. /cloudfuse mounts) are used in place;
+    // remote URIs are copied down to local scratch first.
+    val localFastaFile = if (fastaFile.startsWith("/")) {
+      fastaFile
+    } else {
+      val localPath = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta")
+      log.info(s"copying FASTA file at $fastaFile to $localPath")
+      fs.copyRecode(fastaFile, localPath)
+      localPath
+    }
 
     val localIndexFile = localFastaFile + ".fai"
-    fs.copyRecode(indexFile, localIndexFile)
+    if (localIndexFile != indexFile) {
+      fs.copyRecode(indexFile, localIndexFile)
+    }
 
     if (!fs.exists(localFastaFile))
       fatal(s"Error while copying FASTA file to local file system. Did not find '$localFastaFile'.")
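The `setup` change encodes a single rule: a FASTA path that is already on the local filesystem, notably the `/cloudfuse` mounts the service backend creates, is used in place, while everything else is copied to local scratch first. As a predicate (a sketch; the Scala above is authoritative):

```python
def needs_local_copy(path: str) -> bool:
    # absolute local paths, e.g. /cloudfuse/<bucket>/..., can be opened as-is;
    # remote URIs like gs:// must be copied down before htsjdk can read them
    return not path.startswith('/')

assert not needs_local_copy('/cloudfuse/my-bucket/refs/hg38.fasta')
assert needs_local_copy('gs://my-bucket/refs/hg38.fasta')
```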
diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala
index 2abb3e8b27c6..0284be23ee81 100644
--- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala
+++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala
@@ -332,17 +332,14 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St
     val tmpdir = ctx.localTmpdir
     val fs = ctx.fs
     if (!fs.exists(fastaFile))
-      fatal(s"FASTA file '$fastaFile' does not exist.")
+      fatal(s"FASTA file '$fastaFile' does not exist or you do not have access.")
     if (!fs.exists(indexFile))
-      fatal(s"FASTA index file '$indexFile' does not exist.")
+      fatal(s"FASTA index file '$indexFile' does not exist or you do not have access.")
     fastaFilePath = fastaFile
     fastaIndexPath = indexFile
 
     // assumption, fastaFile and indexFile will not move or change for the entire duration of a hail pipeline
-    val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-add-seq", "fai")
-    fs.copyRecode(indexFile, localIndexFile)
-
-    val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile)))
+    val index = using(fs.open(indexFile))(new FastaSequenceIndex(_))
 
     val missingContigs = contigs.filterNot(index.hasIndexEntry)
     if (missingContigs.nonEmpty)
@@ -578,9 +575,7 @@ object ReferenceGenome {
     if (!fs.exists(indexFile))
       fatal(s"FASTA index file '$indexFile' does not exist.")
 
-    val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-from-fasta", "fai")
-    fs.copyRecode(indexFile, localIndexFile)
-    val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile)))
+    val index = using(fs.open(indexFile))(new FastaSequenceIndex(_))
 
     val contigs = new BoxedArrayBuilder[String]
     val lengths = new BoxedArrayBuilder[(String, Int)]
@@ -592,9 +587,7 @@ object ReferenceGenome {
       lengths += (contig -> length.toInt)
     }
 
-    val rg = ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput)
-    rg.addSequence(ctx, fastaFile, indexFile)
-    rg
+    ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput)
   }
 
   def readReferences(fs: FS, path: String): Array[ReferenceGenome] = {
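A consequence of this last hunk: `fromFASTAFile` now only constructs and returns the ReferenceGenome; attaching the sequence is the caller's responsibility (the Python client does it via `rg.add_sequence`, and backends no longer call `addReference`). The `parInput` strings it receives use the `contig:start-end` form the Python side produces:

```python
def par_string(contig: str, start: int, end: int) -> str:
    # format emitted by reference_genome.py and parsed back in ReferenceGenome.scala
    return '{}:{}-{}'.format(contig, start, end)

assert par_string('chrX', 10001, 2781479) == 'chrX:10001-2781479'
```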