diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 3f7cf2fee363..49cf79de2f3c 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -1,6 +1,5 @@ from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple import abc -import collections import struct from hail.expr.expressions.base_expression import Expression import orjson @@ -65,6 +64,12 @@ async def write_str(strm: afs.WritableStream, s: str): await write_bytes(strm, s.encode('utf-8')) +async def write_str_array(strm: afs.WritableStream, los: List[str]): + await write_int(strm, len(los)) + for s in los: + await write_str(strm, s) + + class EndOfStream(TransientError): pass @@ -176,6 +181,7 @@ class ServiceBackend(Backend): PARSE_VCF_METADATA = 7 INDEX_BGEN = 8 IMPORT_FAM = 9 + FROM_FASTA_FILE = 10 @staticmethod async def create(*, @@ -291,8 +297,6 @@ def __init__(self, self.worker_memory = worker_memory self.name_prefix = name_prefix self.regions = regions - # Source genome -> [Destination Genome -> Chain file] - self._liftovers: Dict[str, Dict[str, str]] = collections.defaultdict(dict) def debug_info(self) -> Dict[str, Any]: return { @@ -335,6 +339,8 @@ async def _rpc(self, timings = Timings() token = secret_alnum_string() with TemporaryDirectory(ensure_exists=False) as iodir: + read_only_fuse_buckets = set() + with timings.step("write input"): async with await self._async_fs.create(iodir + '/in') as infile: nonnull_flag_count = sum(v is not None for v in self.flags.values()) @@ -347,7 +353,7 @@ async def _rpc(self, await write_int(infile, len(custom_references)) for reference_config in custom_references: await write_str(infile, orjson.dumps(reference_config._config).decode('utf-8')) - non_empty_liftovers = {name: liftovers for name, liftovers in self._liftovers.items() if len(liftovers) > 0} + non_empty_liftovers = {rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0} await write_int(infile, len(non_empty_liftovers)) for source_genome_name, liftovers in non_empty_liftovers.items(): await write_str(infile, source_genome_name) @@ -355,6 +361,20 @@ async def _rpc(self, for dest_reference_genome, chain_file in liftovers.items(): await write_str(infile, dest_reference_genome) await write_str(infile, chain_file) + added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None} + await write_int(infile, len(added_sequences)) + for rg_name, (fasta_file, index_file) in added_sequences.items(): + fasta_url = self._async_fs.parse_url(fasta_file) + index_url = self._async_fs.parse_url(index_file) + fasta_bucket = '/'.join(fasta_url.bucket_parts) + index_bucket = '/'.join(index_url.bucket_parts) + + read_only_fuse_buckets.add(fasta_bucket) + read_only_fuse_buckets.add(index_bucket) + + await write_str(infile, rg_name) + await write_str(infile, f'/cloudfuse/{fasta_bucket}/{fasta_url.path}') + await write_str(infile, f'/cloudfuse/{index_bucket}/{index_url.path}') await write_str(infile, str(self.worker_cores)) await write_str(infile, str(self.worker_memory)) await write_int(infile, len(self.regions)) @@ -389,6 +409,7 @@ async def _rpc(self, resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, + cloudfuse=[(bucket, f'/cloudfuse/{bucket}', True) for bucket in read_only_fuse_buckets] ) self._batch = await bb.submit(disable_progress_bar=True) @@ -521,7 +542,23 @@ async def inputs(infile, _): return tblockmatrix._from_json(orjson.loads(resp)) def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par): - raise NotImplementedError("ServiceBackend does not support 'from_fasta_file'") + return async_to_blocking(self._from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par)) + + async def _from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par, *, progress: Optional[BatchProgressBar] = None): + async def inputs(infile, _): + await write_int(infile, ServiceBackend.FROM_FASTA_FILE) + await write_str(infile, tmp_dir()) + await write_str(infile, self.billing_project) + await write_str(infile, self.remote_tmpdir) + await write_str(infile, name) + await write_str(infile, fasta_file) + await write_str(infile, index_file) + await write_str_array(infile, x_contigs) + await write_str_array(infile, y_contigs) + await write_str_array(infile, mt_contigs) + await write_str_array(infile, par) + _, resp, _ = await self._rpc('from_fasta_file(...)', inputs, progress=progress) + return orjson.loads(resp) def load_references_from_dataset(self, path): return async_to_blocking(self._async_load_references_from_dataset(path)) @@ -536,22 +573,21 @@ async def inputs(infile, _): _, resp, _ = await self._rpc('load_references_from_dataset(...)', inputs, progress=progress) return orjson.loads(resp) - def add_sequence(self, name, fasta_file, index_file): - raise NotImplementedError("ServiceBackend does not support 'add_sequence'") + # Sequence and liftover information is stored on the ReferenceGenome + # and there is no persistent backend to keep in sync so these are a no-op. + # Sequence and liftover information are passed on RPC + def add_sequence(self, name, fasta_file, index_file): # pylint: disable=unused-argument + # TODO Error here if the user does not provide a path in the appropriate cloud + pass - def remove_sequence(self, name): - raise NotImplementedError("ServiceBackend does not support 'remove_sequence'") + def remove_sequence(self, name): # pylint: disable=unused-argument + pass - def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): - if name == dest_reference_genome: - raise ValueError(f'Destination reference genome cannot have the same name as this reference {name}.') - if dest_reference_genome in self._liftovers[name]: - raise ValueError(f'Chain file already exists for destination reference {dest_reference_genome}.') - self._liftovers[name][dest_reference_genome] = chain_file + def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): # pylint: disable=unused-argument + pass - def remove_liftover(self, name, dest_reference_genome): - assert dest_reference_genome in self._liftovers[name] - del self._liftovers[name][dest_reference_genome] + def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument + pass def parse_vcf_metadata(self, path): return async_to_blocking(self._async_parse_vcf_metadata(path)) diff --git a/hail/python/hail/genetics/reference_genome.py b/hail/python/hail/genetics/reference_genome.py index dc522e9af3c9..b12ce6a9e485 100644 --- a/hail/python/hail/genetics/reference_genome.py +++ b/hail/python/hail/genetics/reference_genome.py @@ -419,8 +419,8 @@ def from_fasta_file(cls, name, fasta_file, index_file, par_strings = ["{}:{}-{}".format(contig, start, end) for (contig, start, end) in par] config = Env.backend().from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings) - rg = ReferenceGenome._from_config(config, _builtin=True) - rg._sequence_files = (fasta_file, index_file) + rg = ReferenceGenome._from_config(config) + rg.add_sequence(fasta_file, index_file) return rg @typecheck_method(dest_reference_genome=reference_genome_type) @@ -497,6 +497,8 @@ def add_liftover(self, chain_file, dest_reference_genome): Env.backend().add_liftover(self.name, chain_file, dest_reference_genome.name) if dest_reference_genome.name in self._liftovers: raise KeyError(f"Liftover already exists from {self.name} to {dest_reference_genome.name}.") + if dest_reference_genome.name == self.name: + raise ValueError(f'Destination reference genome cannot have the same name as this reference {self.name}.') self._liftovers[dest_reference_genome.name] = chain_file diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py index 410b08826bc8..586d494fb76b 100644 --- a/hail/python/test/hail/genetics/test_reference_genome.py +++ b/hail/python/test/hail/genetics/test_reference_genome.py @@ -35,7 +35,6 @@ def test_reference_genome(): with hl.TemporaryFilename() as filename: gr2.write(filename) -@fails_service_backend() def test_reference_genome_sequence(): gr3 = ReferenceGenome.read(resource("fake_ref_genome.json")) assert gr3.name == "my_reference_genome" diff --git a/hail/python/test/hail/ggplot/test_ggplot.py b/hail/python/test/hail/ggplot/test_ggplot.py index 84d99295a6d9..b2349c7455f2 100644 --- a/hail/python/test/hail/ggplot/test_ggplot.py +++ b/hail/python/test/hail/ggplot/test_ggplot.py @@ -2,7 +2,6 @@ from hail.ggplot import * import numpy as np import math -from ..helpers import fails_service_backend def test_geom_point_line_text_col_area(): diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 2142a1694292..1053a9ef0cc5 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -274,7 +274,6 @@ class LocalBackend( withExecuteContext(timer) { ctx => val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray) - addReference(rg) rg.toJSONString } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index bf96b8e80b68..3e623b475152 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -395,6 +395,20 @@ class ServiceBackend( info(s"Number of BGEN files indexed: ${ files.size }") "null" } + + def fromFASTAFile( + ctx: ExecuteContext, + name: String, + fastaFile: String, + indexFile: String, + xContigs: Array[String], + yContigs: Array[String], + mtContigs: Array[String], + parInput: Array[String] + ): String = { + val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput) + rg.toJSONString + } } class EndOfInputException extends RuntimeException @@ -464,9 +478,12 @@ class ServiceBackendSocketAPI2( private[this] val PARSE_VCF_METADATA = 7 private[this] val INDEX_BGEN = 8 private[this] val IMPORT_FAM = 9 + private[this] val FROM_FASTA_FILE = 10 private[this] val dummy = new Array[Byte](8) + private[this] val log = Logger.getLogger(getClass.getName()) + def read(bytes: Array[Byte], off: Int, n: Int): Unit = { assert(off + n <= bytes.length) var read = 0 @@ -504,6 +521,17 @@ class ServiceBackendSocketAPI2( def readString(): String = new String(readBytes(), StandardCharsets.UTF_8) + def readStringArray(): Array[String] = { + val n = readInt() + val arr = new Array[String](n) + var i = 0 + while (i < n) { + arr(i) = readString() + i += 1 + } + arr + } + def writeBool(b: Boolean): Unit = { out.write(if (b) 1 else 0) } @@ -556,6 +584,16 @@ class ServiceBackendSocketAPI2( } i += 1 } + val nAddedSequences = readInt() + val addedSequences = mutable.Map[String, (String, String)]() + i = 0 + while (i < nAddedSequences) { + val rgName = readString() + val fastaFile = readString() + val indexFile = readString() + addedSequences(rgName) = (fastaFile, indexFile) + i += 1 + } val workerCores = readString() val workerMemory = readString() @@ -595,6 +633,9 @@ class ServiceBackendSocketAPI2( ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome) } } + addedSequences.foreach { case (rg, (fastaFile, indexFile)) => + ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile) + } ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, regions) method(ctx) } @@ -660,13 +701,7 @@ class ServiceBackendSocketAPI2( backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8) ) case INDEX_BGEN => - val nFiles = readInt() - val files = new Array[String](nFiles) - var i = 0 - while (i < nFiles) { - files(i) = readString() - i += 1 - } + val files = readStringArray() val nIndexFiles = readInt() val indexFileMap = mutable.Map[String, String]() i = 0 @@ -702,6 +737,27 @@ class ServiceBackendSocketAPI2( skipInvalidLoci ).getBytes(StandardCharsets.UTF_8) ) + case FROM_FASTA_FILE => + val name = readString() + val fastaFile = readString() + val indexFile = readString() + val xContigs = readStringArray() + val yContigs = readStringArray() + val mtContigs = readStringArray() + val parInput = readStringArray() + withExecuteContext( + "ServiceBackend.fromFASTAFile", + backend.fromFASTAFile( + _, + name, + fastaFile, + indexFile, + xContigs, + yContigs, + mtContigs, + parInput + ).getBytes(StandardCharsets.UTF_8) + ) } writeBool(true) writeBytes(result) diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 5b05607242aa..fc725e4c6644 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -646,7 +646,6 @@ class SparkBackend( withExecuteContext(timer) { ctx => val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray) - addReference(rg) rg.toJSONString } } diff --git a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala index 7f4ac39d5ecf..52e7b890a263 100644 --- a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala +++ b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala @@ -37,12 +37,18 @@ object FASTAReader { } def setup(tmpdir: String, fs: FS, fastaFile: String, indexFile: String): String = { - val localFastaFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta") - log.info(s"copying FASTA file at $fastaFile to $localFastaFile") - fs.copyRecode(fastaFile, localFastaFile) + val localFastaFile = if (fastaFile.startsWith("/")) { + val localPath = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta") + log.info(s"copying FASTA file at $fastaFile to $localPath") + localPath + } else { + fastaFile + } val localIndexFile = localFastaFile + ".fai" - fs.copyRecode(indexFile, localIndexFile) + if (localIndexFile != indexFile) { + fs.copyRecode(indexFile, localIndexFile) + } if (!fs.exists(localFastaFile)) fatal(s"Error while copying FASTA file to local file system. Did not find '$localFastaFile'.") diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala index 2abb3e8b27c6..7d7ee2ddaa9b 100644 --- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala +++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala @@ -331,18 +331,18 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St val tmpdir = ctx.localTmpdir val fs = ctx.fs + + // If this exists check passes the user is considered to have access to these files + // and is permitted to use cached versions corresponding to these URIs if (!fs.exists(fastaFile)) - fatal(s"FASTA file '$fastaFile' does not exist.") + fatal(s"FASTA file '$fastaFile' does not exist or you do not have access.") if (!fs.exists(indexFile)) - fatal(s"FASTA index file '$indexFile' does not exist.") + fatal(s"FASTA index file '$indexFile' does not exist or you do not have access.") fastaFilePath = fastaFile fastaIndexPath = indexFile // assumption, fastaFile and indexFile will not move or change for the entire duration of a hail pipeline - val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-add-seq", "fai") - fs.copyRecode(indexFile, localIndexFile) - - val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile))) + val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) val missingContigs = contigs.filterNot(index.hasIndexEntry) if (missingContigs.nonEmpty) @@ -578,9 +578,7 @@ object ReferenceGenome { if (!fs.exists(indexFile)) fatal(s"FASTA index file '$indexFile' does not exist.") - val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-from-fasta", "fai") - fs.copyRecode(indexFile, localIndexFile) - val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile))) + val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) val contigs = new BoxedArrayBuilder[String] val lengths = new BoxedArrayBuilder[(String, Int)] @@ -592,9 +590,7 @@ object ReferenceGenome { lengths += (contig -> length.toInt) } - val rg = ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput) - rg.addSequence(ctx, fastaFile, indexFile) - rg + ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput) } def readReferences(fs: FS, path: String): Array[ReferenceGenome] = {