[qob] Support loading RGs from FASTA files

daniel-goldstein committed Mar 17, 2023
1 parent b371e1c commit b36d77d

Showing 10 changed files with 142 additions and 48 deletions.
2 changes: 1 addition & 1 deletion hail/Makefile
@@ -288,7 +288,7 @@ check-pip-lockfile:


python/hailtop/hailctl/deploy.yaml: env/cloud_base env/wheel_cloud_path
-python/hailtop/hailctl/deploy.yaml: $(resources) check-pip-lockfile
+python/hailtop/hailctl/deploy.yaml: $(resources)
rm -f $@
echo "dataproc:" >> $@
for FILE in $(notdir $(resources)); do \
77 changes: 59 additions & 18 deletions hail/python/hail/backend/service_backend.py
@@ -1,6 +1,6 @@
from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple
import abc
-import collections
+import math
import struct
from hail.expr.expressions.base_expression import Expression
import orjson
@@ -65,6 +65,12 @@ async def write_str(strm: afs.WritableStream, s: str):
await write_bytes(strm, s.encode('utf-8'))


+async def write_str_array(strm: afs.WritableStream, los: List[str]):
+    await write_int(strm, len(los))
+    for s in los:
+        await write_str(strm, s)


class EndOfStream(TransientError):
pass
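Note: write_str (via write_int) and the new write_str_array define the wire encoding this protocol uses throughout: a 4-byte length prefix followed by the payload, and for arrays a 4-byte element count followed by each length-prefixed element. A minimal standalone sketch of that framing, assuming write_int packs a little-endian 4-byte int (the actual byte order is fixed inside write_int, which is outside this diff):

    import struct
    from typing import List

    def encode_str(s: str) -> bytes:
        # assumed framing: 4-byte little-endian length, then UTF-8 payload
        payload = s.encode('utf-8')
        return struct.pack('<i', len(payload)) + payload

    def encode_str_array(los: List[str]) -> bytes:
        # element count, then each length-prefixed string
        return struct.pack('<i', len(los)) + b''.join(encode_str(s) for s in los)

    assert encode_str_array(['chrX']) == b'\x01\x00\x00\x00\x04\x00\x00\x00chrX'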

@@ -176,6 +182,7 @@ class ServiceBackend(Backend):
PARSE_VCF_METADATA = 7
INDEX_BGEN = 8
IMPORT_FAM = 9
+    FROM_FASTA_FILE = 10

@staticmethod
async def create(*,
@@ -291,8 +298,6 @@ def __init__(self,
self.worker_memory = worker_memory
self.name_prefix = name_prefix
self.regions = regions
-        # Source genome -> [Destination Genome -> Chain file]
-        self._liftovers: Dict[str, Dict[str, str]] = collections.defaultdict(dict)

def debug_info(self) -> Dict[str, Any]:
return {
@@ -341,6 +346,9 @@ async def _rpc(self,
timings = Timings()
token = secret_alnum_string()
with TemporaryDirectory(ensure_exists=False) as iodir:
+            readonly_fuse_buckets = set()
+            storage_requirement_bytes = 0

with timings.step("write input"):
async with await self._async_fs.create(iodir + '/in') as infile:
nonnull_flag_count = sum(v is not None for v in self.flags.values())
@@ -353,14 +361,23 @@
await write_int(infile, len(custom_references))
for reference_config in custom_references:
await write_str(infile, orjson.dumps(reference_config._config).decode('utf-8'))
-                    non_empty_liftovers = {name: liftovers for name, liftovers in self._liftovers.items() if len(liftovers) > 0}
+                    non_empty_liftovers = {rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}
await write_int(infile, len(non_empty_liftovers))
for source_genome_name, liftovers in non_empty_liftovers.items():
await write_str(infile, source_genome_name)
await write_int(infile, len(liftovers))
for dest_reference_genome, chain_file in liftovers.items():
await write_str(infile, dest_reference_genome)
await write_str(infile, chain_file)
+                    added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None}
+                    await write_int(infile, len(added_sequences))
+                    for rg_name, (fasta_file, index_file) in added_sequences.items():
+                        await write_str(infile, rg_name)
+                        for blob in (fasta_file, index_file):
+                            bucket, path = self._get_bucket_and_path(blob)
+                            readonly_fuse_buckets.add(bucket)
+                            storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size()
+                            await write_str(infile, f'/cloudfuse/{bucket}/{path}')
await write_str(infile, str(self.worker_cores))
await write_str(infile, str(self.worker_memory))
await write_int(infile, len(self.regions))
@@ -382,6 +399,9 @@
resources['cpu'] = str(self.driver_cores)
if self.driver_memory is not None:
resources['memory'] = str(self.driver_memory)
+        if storage_requirement_bytes != 0:
+            storage_gib = math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)
+            resources['storage'] = f'{storage_gib}Gi'

j = bb.create_jvm_job(
jar_spec=self.jar_spec.to_dict(),
@@ -395,6 +415,7 @@
resources=resources,
attributes={'name': name + '_driver'},
regions=self.regions,
+                cloudfuse=[(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets]
)
self._batch = await bb.submit(disable_progress_bar=True)
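Net effect of the two additions above: each reference-sequence blob is exposed to the driver job through a read-only /cloudfuse mount of its bucket, and the job's scratch-disk request is the combined blob size rounded up to whole GiB. A standalone sketch of the sizing rule (the byte counts are illustrative):

    import math

    def storage_request(blob_sizes_bytes):
        # sum the blob sizes and round up to whole GiB, the unit
        # Batch's 'storage' resource expects
        return f"{math.ceil(sum(blob_sizes_bytes) / 1024**3)}Gi"

    # e.g. a ~3.1 GiB FASTA plus its small .fai index requests 4Gi
    assert storage_request([3_328_599_654, 6_244]) == '4Gi'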

@@ -527,7 +548,23 @@ async def inputs(infile, _):
return tblockmatrix._from_json(orjson.loads(resp))

def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
raise NotImplementedError("ServiceBackend does not support 'from_fasta_file'")
return async_to_blocking(self._from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par))

+    async def _from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par, *, progress: Optional[BatchProgressBar] = None):
+        async def inputs(infile, _):
+            await write_int(infile, ServiceBackend.FROM_FASTA_FILE)
+            await write_str(infile, tmp_dir())
+            await write_str(infile, self.billing_project)
+            await write_str(infile, self.remote_tmpdir)
+            await write_str(infile, name)
+            await write_str(infile, fasta_file)
+            await write_str(infile, index_file)
+            await write_str_array(infile, x_contigs)
+            await write_str_array(infile, y_contigs)
+            await write_str_array(infile, mt_contigs)
+            await write_str_array(infile, par)
+        _, resp, _ = await self._rpc('from_fasta_file(...)', inputs, progress=progress)
+        return orjson.loads(resp)

def load_references_from_dataset(self, path):
return async_to_blocking(self._async_load_references_from_dataset(path))
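For orientation, the user-facing entry point that ends up at this RPC is ReferenceGenome.from_fasta_file; a usage sketch in which the bucket, file names, contigs, and PAR coordinates are illustrative placeholders:

    import hail as hl

    rg = hl.ReferenceGenome.from_fasta_file(
        'my_genome',                              # name
        'gs://my-bucket/my_genome.fasta.gz',      # fasta_file
        'gs://my-bucket/my_genome.fasta.gz.fai',  # index_file
        x_contigs=['chrX'],
        y_contigs=['chrY'],
        mt_contigs=['chrM'],
        par=[('chrX', 10001, 2781479)])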
@@ -542,22 +579,26 @@ async def inputs(infile, _):
_, resp, _ = await self._rpc('load_references_from_dataset(...)', inputs, progress=progress)
return orjson.loads(resp)

-    def add_sequence(self, name, fasta_file, index_file):
-        raise NotImplementedError("ServiceBackend does not support 'add_sequence'")
+    # Sequence and liftover information are stored on the ReferenceGenome;
+    # there is no persistent backend state to keep in sync, so both are
+    # passed along with every RPC.
+    def add_sequence(self, name, fasta_file, index_file):  # pylint: disable=unused-argument
+        # FIXME Not only should this be in the cloud, it should be in the *right* cloud
+        for blob in (fasta_file, index_file):
+            self.validate_file_scheme(blob)
+
+    def remove_sequence(self, name):  # pylint: disable=unused-argument
+        pass

-    def remove_sequence(self, name):
-        raise NotImplementedError("ServiceBackend does not support 'remove_sequence'")
+    def _get_bucket_and_path(self, blob_uri):
+        url = self._async_fs.parse_url(blob_uri)
+        return '/'.join(url.bucket_parts), url.path

-    def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):
-        if name == dest_reference_genome:
-            raise ValueError(f'Destination reference genome cannot have the same name as this reference {name}.')
-        if dest_reference_genome in self._liftovers[name]:
-            raise ValueError(f'Chain file already exists for destination reference {dest_reference_genome}.')
-        self._liftovers[name][dest_reference_genome] = chain_file
+    def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):  # pylint: disable=unused-argument
+        pass

-    def remove_liftover(self, name, dest_reference_genome):
-        assert dest_reference_genome in self._liftovers[name]
-        del self._liftovers[name][dest_reference_genome]
+    def remove_liftover(self, name, dest_reference_genome):  # pylint: disable=unused-argument
+        pass

def parse_vcf_metadata(self, path):
return async_to_blocking(self._async_parse_vcf_metadata(path))
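The bucket/path split from _get_bucket_and_path is what names the /cloudfuse mount points written into the RPC payload above. A rough pure-Python stand-in (the URL shape is assumed; the real parsing is delegated to the async filesystem):

    def bucket_and_path(blob_uri: str):
        # 'gs://my-bucket/ref/genome.fa' -> ('my-bucket', 'ref/genome.fa')
        rest = blob_uri.split('://', 1)[1]
        bucket, _, path = rest.partition('/')
        return bucket, path

    bucket, path = bucket_and_path('gs://my-bucket/ref/genome.fa')
    assert f'/cloudfuse/{bucket}/{path}' == '/cloudfuse/my-bucket/ref/genome.fa'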
6 changes: 4 additions & 2 deletions hail/python/hail/genetics/reference_genome.py
@@ -419,8 +419,8 @@ def from_fasta_file(cls, name, fasta_file, index_file,
par_strings = ["{}:{}-{}".format(contig, start, end) for (contig, start, end) in par]
config = Env.backend().from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings)

-        rg = ReferenceGenome._from_config(config, _builtin=True)
-        rg._sequence_files = (fasta_file, index_file)
+        rg = ReferenceGenome._from_config(config)
+        rg.add_sequence(fasta_file, index_file)
return rg

@typecheck_method(dest_reference_genome=reference_genome_type)
@@ -497,6 +497,8 @@ def add_liftover(self, chain_file, dest_reference_genome):
Env.backend().add_liftover(self.name, chain_file, dest_reference_genome.name)
if dest_reference_genome.name in self._liftovers:
raise KeyError(f"Liftover already exists from {self.name} to {dest_reference_genome.name}.")
+        if dest_reference_genome.name == self.name:
+            raise ValueError(f'Destination reference genome cannot have the same name as this reference {self.name}.')
self._liftovers[dest_reference_genome.name] = chain_file
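With the checks now living on ReferenceGenome, misuse fails client-side on every backend. A usage sketch (the chain-file path is an illustrative placeholder):

    import hail as hl

    rg37 = hl.get_reference('GRCh37')
    rg38 = hl.get_reference('GRCh38')
    rg37.add_liftover('gs://my-bucket/grch37_to_grch38.over.chain.gz', rg38)
    # a second add_liftover to the same destination raises KeyError;
    # lifting a reference over onto itself raises ValueError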


1 change: 0 additions & 1 deletion hail/python/test/hail/genetics/test_reference_genome.py
@@ -35,7 +35,6 @@ def test_reference_genome():
with hl.TemporaryFilename() as filename:
gr2.write(filename)

-@fails_service_backend()
def test_reference_genome_sequence():
gr3 = ReferenceGenome.read(resource("fake_ref_genome.json"))
assert gr3.name == "my_reference_genome"
1 change: 0 additions & 1 deletion hail/python/test/hail/ggplot/test_ggplot.py
@@ -2,7 +2,6 @@
from hail.ggplot import *
import numpy as np
import math
-from ..helpers import fails_service_backend


def test_geom_point_line_text_col_area():
hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -274,7 +274,6 @@ class LocalBackend(
withExecuteContext(timer) { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
-      addReference(rg)
rg.toJSONString
}
}
70 changes: 63 additions & 7 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -395,6 +395,20 @@ class ServiceBackend(
): TableStage = {
LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses)
}

+  def fromFASTAFile(
+    ctx: ExecuteContext,
+    name: String,
+    fastaFile: String,
+    indexFile: String,
+    xContigs: Array[String],
+    yContigs: Array[String],
+    mtContigs: Array[String],
+    parInput: Array[String]
+  ): String = {
+    val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput)
+    rg.toJSONString
+  }
}

class EndOfInputException extends RuntimeException
@@ -464,9 +478,12 @@ class ServiceBackendSocketAPI2(
private[this] val PARSE_VCF_METADATA = 7
private[this] val INDEX_BGEN = 8
private[this] val IMPORT_FAM = 9
+  private[this] val FROM_FASTA_FILE = 10

private[this] val dummy = new Array[Byte](8)

+  private[this] val log = Logger.getLogger(getClass.getName())

def read(bytes: Array[Byte], off: Int, n: Int): Unit = {
assert(off + n <= bytes.length)
var read = 0
@@ -504,6 +521,17 @@

def readString(): String = new String(readBytes(), StandardCharsets.UTF_8)

+  def readStringArray(): Array[String] = {
+    val n = readInt()
+    val arr = new Array[String](n)
+    var i = 0
+    while (i < n) {
+      arr(i) = readString()
+      i += 1
+    }
+    arr
+  }

def writeBool(b: Boolean): Unit = {
out.write(if (b) 1 else 0)
}
@@ -556,6 +584,16 @@
}
i += 1
}
+    val nAddedSequences = readInt()
+    val addedSequences = mutable.Map[String, (String, String)]()
+    i = 0
+    while (i < nAddedSequences) {
+      val rgName = readString()
+      val fastaFile = readString()
+      val indexFile = readString()
+      addedSequences(rgName) = (fastaFile, indexFile)
+      i += 1
+    }
val workerCores = readString()
val workerMemory = readString()

@@ -595,6 +633,9 @@
ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome)
}
}
+    addedSequences.foreach { case (rg, (fastaFile, indexFile)) =>
+      ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile)
+    }
ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory, regions)
method(ctx)
}
@@ -660,13 +701,7 @@
backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8)
)
case INDEX_BGEN =>
-        val nFiles = readInt()
-        val files = new Array[String](nFiles)
-        var i = 0
-        while (i < nFiles) {
-          files(i) = readString()
-          i += 1
-        }
+        val files = readStringArray()
val nIndexFiles = readInt()
val indexFileMap = mutable.Map[String, String]()
i = 0
@@ -702,6 +737,27 @@
skipInvalidLoci
).getBytes(StandardCharsets.UTF_8)
)
+      case FROM_FASTA_FILE =>
+        val name = readString()
+        val fastaFile = readString()
+        val indexFile = readString()
+        val xContigs = readStringArray()
+        val yContigs = readStringArray()
+        val mtContigs = readStringArray()
+        val parInput = readStringArray()
+        withExecuteContext(
+          "ServiceBackend.fromFASTAFile",
+          backend.fromFASTAFile(
+            _,
+            name,
+            fastaFile,
+            indexFile,
+            xContigs,
+            yContigs,
+            mtContigs,
+            parInput
+          ).getBytes(StandardCharsets.UTF_8)
+        )
}
writeBool(true)
writeBytes(result)
hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -646,7 +646,6 @@ class SparkBackend(
withExecuteContext(timer) { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
-      addReference(rg)
rg.toJSONString
}
}
14 changes: 10 additions & 4 deletions hail/src/main/scala/is/hail/io/reference/FASTAReader.scala
@@ -37,12 +37,18 @@
}

def setup(tmpdir: String, fs: FS, fastaFile: String, indexFile: String): String = {
-    val localFastaFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta")
-    log.info(s"copying FASTA file at $fastaFile to $localFastaFile")
-    fs.copyRecode(fastaFile, localFastaFile)
+    val localFastaFile = if (fastaFile.startsWith("/")) {
+      fastaFile
+    } else {
+      val localPath = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader", "fasta")
+      log.info(s"copying FASTA file at $fastaFile to $localPath")
+      fs.copyRecode(fastaFile, localPath)
+      localPath
+    }

val localIndexFile = localFastaFile + ".fai"
-    fs.copyRecode(indexFile, localIndexFile)
+    if (localIndexFile != indexFile) {
+      fs.copyRecode(indexFile, localIndexFile)
+    }

if (!fs.exists(localFastaFile))
fatal(s"Error while copying FASTA file to local file system. Did not find '$localFastaFile'.")
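The reworked setup encodes one rule: a path that is already local (for example, the read-only /cloudfuse mounts the ServiceBackend configures) is read in place, while a remote URI is staged to local scratch first. A restatement of just that rule, as a sketch:

    def resolve_local_fasta(fasta_file: str, stage_to_tmp) -> str:
        # local paths (e.g. /cloudfuse mounts) are readable in place;
        # remote URIs (gs://, https://, ...) must be copied down first
        if fasta_file.startswith('/'):
            return fasta_file
        return stage_to_tmp(fasta_file)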