[qob] Support loading RGs from FASTA files
daniel-goldstein committed Mar 1, 2023
1 parent 9e0081c commit 8d89cdc
Showing 7 changed files with 117 additions and 42 deletions.
60 changes: 42 additions & 18 deletions hail/python/hail/backend/service_backend.py
@@ -1,6 +1,5 @@
from typing import Dict, Optional, Callable, Awaitable, Mapping, Any, List, Union, Tuple
import abc
import collections
import struct
from hail.expr.expressions.base_expression import Expression
import orjson
@@ -64,6 +63,12 @@ async def write_str(strm: afs.WritableStream, s: str):
await write_bytes(strm, s.encode('utf-8'))


async def write_str_array(strm: afs.WritableStream, los: List[str]):
await write_int(strm, len(los))
for s in los:
await write_str(strm, s)
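
The new write_str_array helper frames a string list the same way the rest of this protocol frames values: an element count, then each element as a length-prefixed UTF-8 string. A minimal sketch of the resulting byte layout, assuming write_int packs a little-endian 32-bit integer (as the struct import suggests):

import struct
from typing import List

def encode_str(s: str) -> bytes:
    # length-prefixed UTF-8, mirroring write_str/write_bytes
    b = s.encode('utf-8')
    return struct.pack('<i', len(b)) + b

def encode_str_array(los: List[str]) -> bytes:
    # element count, then each length-prefixed string, mirroring write_str_array
    return struct.pack('<i', len(los)) + b''.join(encode_str(s) for s in los)

# ['chrX'] encodes as: count 1, then length 4 and the UTF-8 payload
assert encode_str_array(['chrX']) == b'\x01\x00\x00\x00\x04\x00\x00\x00chrX'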


class EndOfStream(TransientError):
pass

@@ -175,6 +180,7 @@ class ServiceBackend(Backend):
PARSE_VCF_METADATA = 7
INDEX_BGEN = 8
IMPORT_FAM = 9
FROM_FASTA_FILE = 10

@staticmethod
async def create(*,
@@ -275,8 +281,6 @@ def __init__(self,
self.worker_cores = worker_cores
self.worker_memory = worker_memory
self.name_prefix = name_prefix
# Source genome -> [Destination Genome -> Chain file]
self._liftovers: Dict[str, Dict[str, str]] = collections.defaultdict(dict)

def debug_info(self) -> Dict[str, Any]:
return {
@@ -330,14 +334,20 @@ async def _rpc(self,
await write_int(infile, len(custom_references))
for reference_config in custom_references:
await write_str(infile, orjson.dumps(reference_config._config).decode('utf-8'))
non_empty_liftovers = {name: liftovers for name, liftovers in self._liftovers.items() if len(liftovers) > 0}
non_empty_liftovers = {rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}
await write_int(infile, len(non_empty_liftovers))
for source_genome_name, liftovers in non_empty_liftovers.items():
await write_str(infile, source_genome_name)
await write_int(infile, len(liftovers))
for dest_reference_genome, chain_file in liftovers.items():
await write_str(infile, dest_reference_genome)
await write_str(infile, chain_file)
added_sequences = {rg.name: rg._sequence_files for rg in self._references.values() if rg._sequence_files is not None}
await write_int(infile, len(added_sequences))
for rg_name, (fasta_file, index_file) in added_sequences.items():
await write_str(infile, rg_name)
await write_str(infile, fasta_file)
await write_str(infile, index_file)
await write_str(infile, str(self.worker_cores))
await write_str(infile, str(self.worker_memory))
await inputs(infile, token)
@@ -500,7 +510,23 @@ async def inputs(infile, _):
return tblockmatrix._from_json(orjson.loads(resp))

def from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par):
raise NotImplementedError("ServiceBackend does not support 'from_fasta_file'")
return async_to_blocking(self._from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par))

async def _from_fasta_file(self, name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par, *, progress: Optional[BatchProgressBar] = None):
async def inputs(infile, _):
await write_int(infile, ServiceBackend.FROM_FASTA_FILE)
await write_str(infile, tmp_dir())
await write_str(infile, self.billing_project)
await write_str(infile, self.remote_tmpdir)
await write_str(infile, name)
await write_str(infile, fasta_file)
await write_str(infile, index_file)
await write_str_array(infile, x_contigs)
await write_str_array(infile, y_contigs)
await write_str_array(infile, mt_contigs)
await write_str_array(infile, par)
_, resp, _ = await self._rpc('from_fasta_file(...)', inputs, progress=progress)
return orjson.loads(resp)

def load_references_from_dataset(self, path):
return async_to_blocking(self._async_load_references_from_dataset(path))
@@ -515,22 +541,20 @@ async def inputs(infile, _):
_, resp, _ = await self._rpc('load_references_from_dataset(...)', inputs, progress=progress)
return orjson.loads(resp)

def add_sequence(self, name, fasta_file, index_file):
raise NotImplementedError("ServiceBackend does not support 'add_sequence'")
# Sequence and liftover information is stored on the ReferenceGenome
# and sent along on each RPC; there is no persistent backend state to
# keep in sync, so these methods are no-ops.
def add_sequence(self, name, fasta_file, index_file): # pylint: disable=unused-argument
pass

def remove_sequence(self, name):
raise NotImplementedError("ServiceBackend does not support 'remove_sequence'")
def remove_sequence(self, name): # pylint: disable=unused-argument
pass

def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str):
if name == dest_reference_genome:
raise ValueError(f'Destination reference genome cannot have the same name as this reference {name}.')
if dest_reference_genome in self._liftovers[name]:
raise ValueError(f'Chain file already exists for destination reference {dest_reference_genome}.')
self._liftovers[name][dest_reference_genome] = chain_file
def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): # pylint: disable=unused-argument
pass

def remove_liftover(self, name, dest_reference_genome):
assert dest_reference_genome in self._liftovers[name]
del self._liftovers[name][dest_reference_genome]
def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument
pass

def parse_vcf_metadata(self, path):
return async_to_blocking(self._async_parse_vcf_metadata(path))
6 changes: 4 additions & 2 deletions hail/python/hail/genetics/reference_genome.py
@@ -419,8 +419,8 @@ def from_fasta_file(cls, name, fasta_file, index_file,
par_strings = ["{}:{}-{}".format(contig, start, end) for (contig, start, end) in par]
config = Env.backend().from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings)

rg = ReferenceGenome._from_config(config, _builtin=True)
rg._sequence_files = (fasta_file, index_file)
rg = ReferenceGenome._from_config(config)
rg.add_sequence(fasta_file, index_file)
return rg
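
With the new FROM_FASTA_FILE RPC, this classmethod now works under the service backend (Query-on-Batch) as well as Spark. A usage sketch (the bucket paths and PAR interval below are illustrative, not taken from this change):

import hail as hl

rg = hl.ReferenceGenome.from_fasta_file(
    name='my_genome',
    fasta_file='gs://my-bucket/my_genome.fa.gz',      # hypothetical path
    index_file='gs://my-bucket/my_genome.fa.gz.fai',  # hypothetical path
    x_contigs=['chrX'],
    y_contigs=['chrY'],
    mt_contigs=['chrM'],
    par=[('chrX', 10001, 2781479)])                   # illustrative PAR interval
assert rg.has_sequence()  # add_sequence is applied client-side after the RPC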

@typecheck_method(dest_reference_genome=reference_genome_type)
@@ -497,6 +497,8 @@ def add_liftover(self, chain_file, dest_reference_genome):
Env.backend().add_liftover(self.name, chain_file, dest_reference_genome.name)
if dest_reference_genome.name in self._liftovers:
raise KeyError(f"Liftover already exists from {self.name} to {dest_reference_genome.name}.")
if dest_reference_genome.name == self.name:
raise ValueError(f'Destination reference genome cannot have the same name as this reference {self.name}.')
self._liftovers[dest_reference_genome.name] = chain_file
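
The same-name and duplicate-chain-file validation that previously lived in the service backend now happens on ReferenceGenome itself, so it applies uniformly across backends. A usage sketch (the chain-file path is hypothetical):

import hail as hl

rg37 = hl.get_reference('GRCh37')
rg38 = hl.get_reference('GRCh38')

# hypothetical chain file location
rg37.add_liftover('gs://my-bucket/grch37_to_grch38.over.chain.gz', rg38)

# calling add_liftover again for the same destination raises KeyError;
# lifting a genome over onto itself raises ValueError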


1 change: 0 additions & 1 deletion hail/python/test/hail/genetics/test_reference_genome.py
@@ -35,7 +35,6 @@ def test_reference_genome():
with hl.TemporaryFilename() as filename:
gr2.write(filename)

@fails_service_backend()
def test_reference_genome_sequence():
gr3 = ReferenceGenome.read(resource("fake_ref_genome.json"))
assert gr3.name == "my_reference_genome"
hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -274,7 +274,6 @@ class LocalBackend(
withExecuteContext(timer) { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
addReference(rg)
rg.toJSONString
}
}
70 changes: 63 additions & 7 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
@@ -393,6 +393,20 @@ class ServiceBackend(
info(s"Number of BGEN files indexed: ${ files.size }")
"null"
}

def fromFASTAFile(
ctx: ExecuteContext,
name: String,
fastaFile: String,
indexFile: String,
xContigs: Array[String],
yContigs: Array[String],
mtContigs: Array[String],
parInput: Array[String]
): String = {
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput)
rg.toJSONString
}
}

class EndOfInputException extends RuntimeException
@@ -462,9 +476,12 @@ class ServiceBackendSocketAPI2(
private[this] val PARSE_VCF_METADATA = 7
private[this] val INDEX_BGEN = 8
private[this] val IMPORT_FAM = 9
private[this] val FROM_FASTA_FILE = 10

private[this] val dummy = new Array[Byte](8)

private[this] val log = Logger.getLogger(getClass.getName())

def read(bytes: Array[Byte], off: Int, n: Int): Unit = {
assert(off + n <= bytes.length)
var read = 0
@@ -502,6 +519,17 @@

def readString(): String = new String(readBytes(), StandardCharsets.UTF_8)

def readStringArray(): Array[String] = {
val n = readInt()
val arr = new Array[String](n)
var i = 0
while (i < n) {
arr(i) = readString()
i += 1
}
arr
}

def writeBool(b: Boolean): Unit = {
out.write(if (b) 1 else 0)
}
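
readStringArray is the Scala decoding counterpart of the Python write_str_array added above. For illustration only, a blocking Python sketch of the same read side (these helpers are hypothetical, not part of this change, and assume the writer's little-endian 32-bit framing):

import struct
from typing import BinaryIO, List

def read_int(strm: BinaryIO) -> int:
    return struct.unpack('<i', strm.read(4))[0]

def read_str(strm: BinaryIO) -> str:
    n = read_int(strm)
    return strm.read(n).decode('utf-8')

def read_str_array(strm: BinaryIO) -> List[str]:
    # count, then that many length-prefixed strings
    return [read_str(strm) for _ in range(read_int(strm))]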
@@ -554,6 +582,16 @@
}
i += 1
}
val nAddedSequences = readInt()
val addedSequences = mutable.Map[String, (String, String)]()
i = 0
while (i < nAddedSequences) {
val rgName = readString()
val fastaFile = readString()
val indexFile = readString()
addedSequences(rgName) = (fastaFile, indexFile)
i += 1
}
val workerCores = readString()
val workerMemory = readString()

@@ -582,6 +620,9 @@
ctx.getReference(sourceGenome).addLiftover(ctx, chainFile, destGenome)
}
}
addedSequences.foreach { case (rg, (fastaFile, indexFile)) =>
ctx.getReference(rg).addSequence(ctx, fastaFile, indexFile)
}
ctx.backendContext = new ServiceBackendContext(sessionId, billingProject, remoteTmpDir, workerCores, workerMemory)
method(ctx)
}
@@ -647,13 +688,7 @@
backend.importFam(_, path, quantPheno, delimiter, missing).getBytes(StandardCharsets.UTF_8)
)
case INDEX_BGEN =>
val nFiles = readInt()
val files = new Array[String](nFiles)
var i = 0
while (i < nFiles) {
files(i) = readString()
i += 1
}
val files = readStringArray()
val nIndexFiles = readInt()
val indexFileMap = mutable.Map[String, String]()
i = 0
@@ -689,6 +724,27 @@
skipInvalidLoci
).getBytes(StandardCharsets.UTF_8)
)
case FROM_FASTA_FILE =>
val name = readString()
val fastaFile = readString()
val indexFile = readString()
val xContigs = readStringArray()
val yContigs = readStringArray()
val mtContigs = readStringArray()
val parInput = readStringArray()
withExecuteContext(
"ServiceBackend.fromFASTAFile",
backend.fromFASTAFile(
_,
name,
fastaFile,
indexFile,
xContigs,
yContigs,
mtContigs,
parInput
).getBytes(StandardCharsets.UTF_8)
)
}
writeBool(true)
writeBytes(result)
hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -646,7 +646,6 @@
withExecuteContext(timer) { ctx =>
val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray)
addReference(rg)
rg.toJSONString
}
}
20 changes: 8 additions & 12 deletions hail/src/main/scala/is/hail/variant/ReferenceGenome.scala
@@ -331,18 +331,18 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St

val tmpdir = ctx.localTmpdir
val fs = ctx.fs

// If this existence check passes, the user is considered to have access to these files
// and is permitted to use cached versions corresponding to these URIs
if (!fs.exists(fastaFile))
fatal(s"FASTA file '$fastaFile' does not exist.")
fatal(s"FASTA file '$fastaFile' does not exist or you do not have access.")
if (!fs.exists(indexFile))
fatal(s"FASTA index file '$indexFile' does not exist.")
fatal(s"FASTA index file '$indexFile' does not exist or you do not have access.")
fastaFilePath = fastaFile
fastaIndexPath = indexFile

// assumption: fastaFile and indexFile will not move or change for the entire duration of a hail pipeline
val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-add-seq", "fai")
fs.copyRecode(indexFile, localIndexFile)

val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile)))
val index = using(fs.open(indexFile))(new FastaSequenceIndex(_))

val missingContigs = contigs.filterNot(index.hasIndexEntry)
if (missingContigs.nonEmpty)
@@ -578,9 +578,7 @@ object ReferenceGenome {
if (!fs.exists(indexFile))
fatal(s"FASTA index file '$indexFile' does not exist.")

val localIndexFile = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fasta-reader-from-fasta", "fai")
fs.copyRecode(indexFile, localIndexFile)
val index = new FastaSequenceIndex(new java.io.File(uriPath(localIndexFile)))
val index = using(fs.open(indexFile))(new FastaSequenceIndex(_))

val contigs = new BoxedArrayBuilder[String]
val lengths = new BoxedArrayBuilder[(String, Int)]
@@ -592,9 +590,7 @@
lengths += (contig -> length.toInt)
}

val rg = ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput)
rg.addSequence(ctx, fastaFile, indexFile)
rg
ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput)
}
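
fromFASTAFile never reads the FASTA itself: the contig names and lengths come entirely from the .fai index, which is why opening the index stream directly (rather than copying it to a local file first) is sufficient. For illustration, a Python sketch of that derivation from the standard five-column faidx format (helper name is hypothetical):

from typing import Dict, List, Tuple

def contigs_and_lengths(fai_text: str) -> Tuple[List[str], Dict[str, int]]:
    # each .fai line: NAME <tab> LENGTH <tab> OFFSET <tab> LINEBASES <tab> LINEWIDTH
    contigs: List[str] = []
    lengths: Dict[str, int] = {}
    for line in fai_text.splitlines():
        if line.strip():
            name, length = line.split('\t')[:2]
            contigs.append(name)
            lengths[name] = int(length)
    return contigs, lengths

contigs, lengths = contigs_and_lengths('chr1\t248956422\t112\t70\t71\n')
assert contigs == ['chr1'] and lengths['chr1'] == 248956422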

def readReferences(fs: FS, path: String): Array[ReferenceGenome] = {
