Skip to content

Commit

Permalink
feat(call): optional incorporation of HP tags from haplotagged alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Feb 1, 2024
1 parent b2fa668 commit 71c2d7f
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 42 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ If more than one read file is specified, the reads will be pooled. This can come
have e.g. multiple flow cells of the same sample split into different BAM files, or the reads are
split by chromosome.

If you want to **incorporate haplotagging from an alignment file (`HP` tags)** into the
process, which should speed up runtime and potentially improve calling results, you must pass
the `--use-hp` flag. **This flag is experimental, and has not been tested extensively.**

If you want to **incorporate SNV calling** into the process, which speeds up runtime and gives
marginally better calling results, you must provide an indexed, `bgzip`-compressed SNV catalog
VCF which matches your reference genome. You can find dbSNP VCFs at
Expand Down
4 changes: 3 additions & 1 deletion docs/caller_usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
or small indels.
* `--hq`: Whether to treat provided reads as "high quality", i.e., fairly close to the actual true sequence. Used when
detecting expansions, to skip a smoothing filter that may ignore disparate, rare expansion-like read counts.
Use for CCS reads or similar ONLY! **Default:** off
Use for CCS reads or similar data (e.g., accurate nanopore sequences) ONLY! **Default:** off
* `--use-hp`: Whether to incorporate `HP` tags from a haplotagged alignment file. This should speed up runtime and
will potentially improve calling results. **This flag is experimental, and has not been tested extensively.**
* `--incorporate-snvs [path]` or `--snv [path]`: A path to a VCF with SNVs to incorporate into the calling process and
final output. This file is just used as an SNV loci catalog; STRkit itself will perform the SNV calling. Empirically
improves calling quality a small amount, speeds up runtime, and gives nearby SNV calls for downstream analysis.
Expand Down
195 changes: 157 additions & 38 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,90 @@ def process_read_snvs_for_locus(
return locus_snvs


def call_alleles_with_haplotags(
num_bootstrap: int,
min_allele_reads: int,
hq: bool,
fractional: bool,
haplotags: list[str],
read_dict_items: tuple[tuple[str, ReadDict], ...], # We could derive this again, but we already have before...
rng: np.random.Generator,
logger_: logging.Logger,
locus_log_str: str,
) -> Optional[dict]:
n_alleles: int = len(haplotags)

hp_reads: list[tuple[ReadDict, ...]] = []
cns: Union[list[list[int]], list[list[float]]] = []
c_ws: list[Union[NDArray[np.int_], NDArray[np.float_]]] = []

for hi, hp in enumerate(haplotags):
# Find reads for cluster
crs: tuple[ReadDict, ...] = tuple(r for i, (_, r) in enumerate(read_dict_items) if r.get("hp") == hp)

# Calculate copy number set
cns.append(list(map(cn_getter, crs)))

# Calculate weights array
ws = np.fromiter(map(weight_getter, crs), dtype=np.float_)
c_ws.append(ws / np.sum(ws))

hp_reads.append(crs)

cdd: list[CallDict] = []

for hi, hp in enumerate(haplotags):
cc: Optional[CallDict] = call_alleles(
cns[hi], (), # Don't bother separating by strand for now...
c_ws[hi], (),
bootstrap_iterations=num_bootstrap,
min_reads=min_allele_reads, # Calling alleles separately, so set min_reads=min_allele_reads
min_allele_reads=min_allele_reads,
n_alleles=1, # Calling alleles separately: they were pre-separated by agglom. clustering
separate_strands=False,
read_bias_corr_min=0, # separate_strands is false, so this is ignored
gm_filter_factor=1, # n_alleles=1, so this is ignored
hq=hq,
force_int=not fractional,
seed=get_new_seed(rng),
logger_=logger_,
debug_str=f"{locus_log_str} a{hi}"
)

if cc is None: # Early escape
return None

# TODO: set peak weight [0] to the sum of read weights - we normalize this later, but this way
# call dicts with more reads will GET MORE WEIGHT! as it should be, instead of 50/50 for the peak.

cdd.append(cc)

# TODO: Multi-allele phasing across STRs

for i in range(len(haplotags)): # Cluster indices now referring to ordered ones
for rd in hp_reads[i]:
rd["p"] = i

peak_weights_pre_adj = np.concatenate(tuple(cc["peak_weights"] for cc in cdd), axis=0)

# All call_datas are truth-y; all arrays should be ordered by peak_order
call_data = {
"call": np.concatenate(tuple(cc["call"] for cc in cdd), axis=0),
"call_95_cis": np.concatenate(tuple(cc["call_95_cis"] for cc in cdd), axis=0),
"call_99_cis": np.concatenate(tuple(cc["call_99_cis"] for cc in cdd), axis=0),
"peaks": np.concatenate(tuple(cc["peaks"] for cc in cdd), axis=None),

# TODO: Readjust peak weights when combining or don't include
# Make peak weights sum to 1
"peak_weights": peak_weights_pre_adj / np.sum(peak_weights_pre_adj),

"peak_stdevs": np.concatenate(tuple(cc["peak_stdevs"] for cc in cdd), axis=0),
"modal_n_peaks": n_alleles, # n. of alleles = n. of peaks always -- if we phased using SNVs
}

return call_data


def call_alleles_with_incorporated_snvs(
n_alleles: int,
num_bootstrap: int,
Expand Down Expand Up @@ -568,6 +652,7 @@ def call_locus(
sample_id: Optional[str] = None,
realign: bool = False,
hq: bool = False,
use_hp: bool = False,
# incorporate_snvs: bool = False,
snv_vcf_file: Optional[pysam.VariantFile] = None,
snv_vcf_contigs: tuple[str, ...] = (),
Expand Down Expand Up @@ -683,6 +768,8 @@ def call_locus(
read_dict: dict[str, ReadDict] = {}
read_dict_extra: dict[str, ReadDictExtra] = {}
realign_count: int = 0 # Number of realigned reads
haplotagged_reads_count: int = 0 # Number of reads with HP tags
haplotags: set[str] = set()

# Aggregations for additional read-level data
read_kmers: Counter[str] = Counter()
Expand Down Expand Up @@ -714,7 +801,7 @@ def call_locus(
# Soft-clipping in large insertions can result from mapping difficulties.
# If we have a soft clip which overlaps with our TR region (+ flank), we can try to recover it
# via realignment with parasail.
# 4: BAM code for soft clip
# 4: BAM code for soft clip CIGAR operation
# TODO: if some alignment is present, use it to reduce realignment overhead?
# - use start point + flank*3 or end point - flank*3 or something like that
if realign and (force_realign or (
Expand Down Expand Up @@ -855,6 +942,13 @@ def call_locus(

# Reads can show up more than once - TODO - cache this information across loci

if use_hp:
tags = dict(segment.get_tags())
if (hp := tags.get("HP")) is not None:
read_dict[rn]["hp"] = hp
haplotags.add(hp)
haplotagged_reads_count += 1

if should_incorporate_snvs:
# Store the segment sequence in the read dict for the next go-around if we've enabled SNV incorporation,
# in order to pass the query sequence to the get_read_snvs function with the cached ref string.
Expand Down Expand Up @@ -911,10 +1005,11 @@ def call_locus(
# noinspection PyTypeChecker
read_dict_items: tuple[tuple[str, ReadDict], ...] = tuple(read_dict.items())

assign_method: Literal["dist", "snv", "snv+dist", "single"] = "dist"
assign_method: Literal["dist", "snv", "snv+dist", "single", "hp"] = "dist"
if n_alleles < 2:
assign_method = "single"

min_hp_read_coverage: int = 8 # TODO: parametrize
min_snv_read_coverage: int = 8 # TODO: parametrize

# Realigns are missing significant amounts of flanking information since the realignment only uses a portion of the
Expand All @@ -927,51 +1022,74 @@ def call_locus(
have_rare_realigns = True
break

if realign_count >= many_realigns_threshold or have_rare_realigns:
logger_.warning(
f"{locus_log_str} - cannot use SNVs; one of {realign_count=} >= {many_realigns_threshold} or "
f"{have_rare_realigns=}")

elif should_incorporate_snvs and n_reads_in_dict >= min_snv_read_coverage and not have_rare_realigns:
# LIMITATION: Currently can only use SNVs for haplotyping with haploid/diploid

# Second read loop occurs in this function
locus_snvs: set[int] = process_read_snvs_for_locus(
contig, left_coord_adj, right_coord_adj, left_most_coord, right_most_coord, ref, read_dict_items,
read_dict_extra, read_pairs, candidate_snvs_dict, only_known_snvs, logger_, locus_log_str)

useful_snvs: list[tuple[int, int]] = calculate_useful_snvs(
n_reads_in_dict, read_dict_items, read_dict_extra, read_pairs, locus_snvs, min_allele_reads)
n_useful_snvs: int = len(useful_snvs)

if not n_useful_snvs:
logger_.debug(f"{locus_log_str} - no useful SNVs")
else:
am, call_res = call_alleles_with_incorporated_snvs(
n_alleles=n_alleles,
if use_hp:
if haplotagged_reads_count >= min_hp_read_coverage and len(haplotags) == n_alleles:
hp_sorted = sorted(haplotags)
call_res = call_alleles_with_haplotags(
num_bootstrap=num_bootstrap,
min_allele_reads=min_allele_reads,
hq=hq,
fractional=fractional,
read_dict=read_dict,
haplotags=hp_sorted,
read_dict_items=read_dict_items,
read_dict_extra=read_dict_extra,
n_reads_in_dict=n_reads_in_dict,
useful_snvs=useful_snvs,
candidate_snvs_dict=candidate_snvs_dict,
rng=rng,
logger_=logger_,
locus_log_str=locus_log_str,
)
assign_method = am
if call_res is not None:
call_data = call_res[0] # Call data dictionary
call_dict_base["snvs"] = call_res[1] # Called useful SNVs

elif n_reads_in_dict < min_snv_read_coverage:
logger_.debug(
f"{locus_log_str} - not enough coverage for SNV incorporation "
f"({n_reads_in_dict} < {min_snv_read_coverage})")
assign_method = "hp"
call_data = call_res
else:
logger_.debug(
f"{locus_log_str} - Not enough HP tags for incorporation; one of {haplotagged_reads_count} < "
f"{min_hp_read_coverage} or {len(haplotags)} != {n_alleles}")

if should_incorporate_snvs and assign_method != "hp":
if realign_count >= many_realigns_threshold or have_rare_realigns:
logger_.warning(
f"{locus_log_str} - cannot use SNVs; one of {realign_count=} >= {many_realigns_threshold} or "
f"{have_rare_realigns=}")

elif n_reads_in_dict >= min_snv_read_coverage and not have_rare_realigns:
# LIMITATION: Currently can only use SNVs for haplotyping with haploid/diploid

# Second read loop occurs in this function
locus_snvs: set[int] = process_read_snvs_for_locus(
contig, left_coord_adj, right_coord_adj, left_most_coord, right_most_coord, ref, read_dict_items,
read_dict_extra, read_pairs, candidate_snvs_dict, only_known_snvs, logger_, locus_log_str)

useful_snvs: list[tuple[int, int]] = calculate_useful_snvs(
n_reads_in_dict, read_dict_items, read_dict_extra, read_pairs, locus_snvs, min_allele_reads)
n_useful_snvs: int = len(useful_snvs)

if not n_useful_snvs:
logger_.debug(f"{locus_log_str} - no useful SNVs")
else:
am, call_res = call_alleles_with_incorporated_snvs(
n_alleles=n_alleles,
num_bootstrap=num_bootstrap,
min_allele_reads=min_allele_reads,
hq=hq,
fractional=fractional,
read_dict=read_dict,
read_dict_items=read_dict_items,
read_dict_extra=read_dict_extra,
n_reads_in_dict=n_reads_in_dict,
useful_snvs=useful_snvs,
candidate_snvs_dict=candidate_snvs_dict,
rng=rng,
logger_=logger_,
locus_log_str=locus_log_str,
)
assign_method = am
if call_res is not None:
call_data = call_res[0] # Call data dictionary
call_dict_base["snvs"] = call_res[1] # Called useful SNVs

elif n_reads_in_dict < min_snv_read_coverage:
logger_.debug(
f"{locus_log_str} - not enough coverage for SNV incorporation "
f"({n_reads_in_dict} < {min_snv_read_coverage})")

single_or_dist_assign: bool = assign_method in ("single", "dist")

Expand Down Expand Up @@ -1102,7 +1220,8 @@ def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, float, np.int
call_val = apply_or_none(_ndarray_serialize, call)
call_95_cis_val = apply_or_none(_nested_ndarray_serialize, call_95_cis)

logger_.debug(f"{locus_log_str} - got call: {call_val} (95% CIs: {call_95_cis_val})")
logger_.debug(
f"{locus_log_str} - got call: {call_val} (95% CIs: {call_95_cis_val}); peak assign method={assign_method}")

return {
**call_dict_base,
Expand Down
8 changes: 6 additions & 2 deletions strkit/call/call_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def locus_worker(
sample_id: Optional[str],
realign: bool,
hq: bool,
use_hp: bool,
# incorporate_snvs: bool,
snv_vcf: Optional[str],
targeted: bool,
Expand Down Expand Up @@ -112,6 +113,7 @@ def locus_worker(
sample_id=sample_id,
realign=realign,
hq=hq,
use_hp=use_hp,
# incorporate_snvs=incorporate_snvs,
snv_vcf_file=snv_vcf_file,
snv_vcf_contigs=tuple(snv_vcf_contigs),
Expand Down Expand Up @@ -200,6 +202,7 @@ def call_sample(
sex_chroms: Optional[str] = None,
realign: bool = False,
hq: bool = False,
use_hp: bool = False,
# incorporate_snvs: bool = False,
snv_vcf: Optional[pathlib.Path] = None,
targeted: bool = False,
Expand Down Expand Up @@ -241,8 +244,8 @@ def call_sample(
sample_id_final: Optional[str] = sample_id or bam_sample_id

logger.info(
f"Starting STR genotyping; sample={sample_id_final}, hq={hq}, targeted={targeted}, SNVs={snv_vcf is not None}; "
f"seed={seed}")
f"Starting STR genotyping; sample={sample_id_final}, hq={hq}, targeted={targeted}, HP={use_hp}, "
f"SNVs={snv_vcf is not None}; seed={seed}")

# Seed the random number generator if a seed is provided, for replicability
rng: np.random.Generator = np.random.default_rng(seed=seed)
Expand Down Expand Up @@ -286,6 +289,7 @@ def call_sample(
"sample_id": sample_id_final,
"realign": realign,
"hq": hq,
"use_hp": use_hp,
# "incorporate_snvs": incorporate_snvs,
"snv_vcf": str(snv_vcf) if snv_vcf else None,
"targeted": targeted,
Expand Down
5 changes: 4 additions & 1 deletion strkit/call/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ class ReadDict(_ReadDictBase, total=False):

kmers: dict[str, int] # Dictionary of {kmer: count}

# Below are only added if SNVs are being incorporated:
# Only added if HP tags from a haplotagged alignment file are being incorporated:
hp: str

# Only added if SNVs are being incorporated:
snvu: tuple[str, ...] # After including only useful SNVs, this contains a tuple of bases for just those


Expand Down
6 changes: 6 additions & 0 deletions strkit/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def add_call_parser_args(call_parser):
help="Whether to treat provided reads as 'fairly' accurate, i.e. not liable to be extremely far away from the "
"DNA truth. Recommended for CCS, and CCS ONLY!")

call_parser.add_argument(
"--use-hp",
action="store_true",
help="Whether to use HP tags from the alignment file (i.e., a haplotagged alignment file), if available.")

call_parser.add_argument(
"--incorporate-snvs", "--snv", "-v",
type=pathlib.Path,
Expand Down Expand Up @@ -364,6 +369,7 @@ def _exec_call(p_args) -> None:
sex_chroms=p_args.sex_chr,
realign=p_args.realign,
hq=p_args.hq,
use_hp=p_args.use_hp,
# incorporate_snvs=p_args.incorporate_snvs,
snv_vcf=p_args.incorporate_snvs,
targeted=p_args.targeted,
Expand Down

0 comments on commit 71c2d7f

Please sign in to comment.