Skip to content

Commit

Permalink
refact(call): move call_alleles_with_gmm procedure to function
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlougheed committed Jun 12, 2024
1 parent ae28967 commit 9e6a0bf
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,48 @@ def calculate_read_distance(
return distance_matrix


def call_alleles_with_gmm(
params: CallParams,
n_alleles: int,
read_dict: dict[str, ReadDict],
assign_method: str,
# ---
rng: np.random.Generator,
# ---
logger_: logging.Logger,
locus_log_str: str,
) -> CallDict | dict:
# Dicts are ordered in Python; very nice :)
rdvs = tuple(read_dict.values())
read_cns = np.fromiter(map(cn_getter, rdvs), dtype=np.int_)
read_weights = np.fromiter(map(weight_getter, rdvs), dtype=np.float_)
read_weights /= read_weights.sum() # Normalize to probabilities

logger_.debug(f"{locus_log_str} - assigning alleles using {assign_method} method with {read_cns.shape[0]} reads")

return call_alleles(
read_cns, (),
read_weights, (),
params=params,
min_reads=params.min_reads,
n_alleles=n_alleles,
separate_strands=False,
read_bias_corr_min=0, # TODO: parametrize
gm_filter_factor=3, # TODO: parametrize
seed=get_new_seed(rng),
logger_=logger_,
debug_str=locus_log_str,
) or {} # Still false-y


def call_alleles_with_haplotags(
params: CallParams,
haplotags: list[int],
ps_id: int,
read_dict_items: tuple[tuple[str, ReadDict], ...], # We could derive this again, but we already have before...
# ---
rng: np.random.Generator,
# ---
logger_: logging.Logger,
locus_log_str: str,
) -> Optional[dict]:
Expand All @@ -266,7 +302,7 @@ def call_alleles_with_haplotags(

# Calculate weights array
ws = np.fromiter(map(weight_getter, crs), dtype=np.float_)
c_ws.append(ws / np.sum(ws))
c_ws.append(ws / ws.sum())

hp_reads.append(crs)

Expand Down Expand Up @@ -710,7 +746,7 @@ def call_alleles_with_incorporated_snvs(

# TODO: Readjust peak weights when combining or don't include
# Make peak weights sum to 1
"peak_weights": peak_weights_pre_adj / np.sum(peak_weights_pre_adj),
"peak_weights": peak_weights_pre_adj / peak_weights_pre_adj.sum(),

"peak_stdevs": np.concatenate(tuple(cc["peak_stdevs"] for cc in cdd_ordered), axis=0),
"modal_n_peaks": n_alleles, # n. of alleles = n. of peaks always -- if we phased using SNVs
Expand Down Expand Up @@ -1382,28 +1418,7 @@ def call_locus(
single_or_dist_assign: bool = assign_method in ("single", "dist")

if single_or_dist_assign: # Didn't use SNVs, so call the 'old-fashioned' way - using only copy number
# Dicts are ordered in Python; very nice :)
rdvs = tuple(read_dict.values())
rcns = tuple(map(cn_getter, rdvs))
read_cns = np.fromiter(rcns, dtype=np.int_)
read_weights = np.fromiter(map(weight_getter, rdvs), dtype=np.float_)
read_weights = read_weights / np.sum(read_weights) # Normalize to probabilities

logger_.debug(f"{locus_log_str} - assigning alleles using {assign_method} method with {len(rcns)} reads")

call_data = call_alleles(
read_cns, (),
read_weights, (),
params=params,
min_reads=params.min_reads,
n_alleles=n_alleles,
separate_strands=False,
read_bias_corr_min=0, # TODO: parametrize
gm_filter_factor=3, # TODO: parametrize
seed=get_new_seed(rng),
logger_=logger_,
debug_str=locus_log_str,
) or {} # Still false-y
call_data = call_alleles_with_gmm(params, n_alleles, read_dict, assign_method, rng, logger_, locus_log_str)

allele_time = (datetime.now() - allele_start_time).total_seconds()

Expand Down

0 comments on commit 9e6a0bf

Please sign in to comment.