Skip to content

Commit

Permalink
Merge pull request #82 from hall-lab/max_ci_dist
Browse files Browse the repository at this point in the history
Stub in code for toggling use of 95 pct confidence interval
  • Loading branch information
ernfrid authored Sep 13, 2018
2 parents 19aa65f + 3afc778 commit 070e78b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 27 deletions.
17 changes: 10 additions & 7 deletions svtyper/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import svtyper.version

from svtyper.parsers import Vcf, Variant, Sample
from svtyper.parsers import Vcf, Variant, Sample, confidence_interval
from svtyper.utils import *
from svtyper.statistics import bayes_gt

Expand All @@ -32,6 +32,7 @@ def get_args():
parser.add_argument('-n', dest='num_samp', metavar='INT', type=int, required=False, default=1000000, help='number of reads to sample from BAM file for building insert size distribution [1000000]')
parser.add_argument('-q', '--sum_quals', action='store_true', required=False, help='add genotyping quality to existing QUAL (default: overwrite QUAL field)')
parser.add_argument('--max_reads', metavar='INT', type=int, default=None, required=False, help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: unlimited)')
parser.add_argument('--max_ci_dist', metavar='INT', type=int, default=1e10, required=False, help='maximum size of a confidence interval before 95% CI is used intead (default: 1e10)')
parser.add_argument('--split_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for split reads [1]')
parser.add_argument('--disc_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for discordant paired-end reads [1]')
parser.add_argument('-w', '--write_alignment', metavar='FILE', dest='alignment_outpath', type=str, required=False, default=None, help='write relevant reads to BAM file')
Expand Down Expand Up @@ -115,7 +116,8 @@ def sv_genotype(bam_string,
alignment_outpath,
ref_fasta,
sum_quals,
max_reads):
max_reads,
max_ci_dist):

# parse the comma separated inputs
bam_list = []
Expand Down Expand Up @@ -236,8 +238,8 @@ def sv_genotype(bam_string,
posA = var.pos
posB = var2.pos
# confidence intervals
ciA = map(int, var.info['CIPOS'].split(','))
ciB = map(int, var2.info['CIPOS'].split(','))
ciA = confidence_interval(var, 'CIPOS', 'CIPOS95', max_ci_dist)
ciB = confidence_interval(var2, 'CIPOS', 'CIPOS95', max_ci_dist)

# infer the strands from the alt allele
if var.alt[-1] == '[' or var.alt[-1] == ']':
Expand All @@ -259,8 +261,8 @@ def sv_genotype(bam_string,
posA = var.pos
posB = int(var.get_info('END'))
# confidence intervals
ciA = map(int, var.info['CIPOS'].split(','))
ciB = map(int, var.info['CIEND'].split(','))
ciA = confidence_interval(var, 'CIPOS', 'CIPOS95', max_ci_dist)
ciB = confidence_interval(var, 'CIEND', 'CIEND95', max_ci_dist)
if svtype == 'DEL':
var_length = posB - posA
o1_is_reverse, o2_is_reverse = False, True
Expand Down Expand Up @@ -561,7 +563,8 @@ def main():
args.alignment_outpath,
args.ref_fasta,
args.sum_quals,
args.max_reads)
args.max_reads,
args.max_ci_dist)

# --------------------------------------
# command-line/console entrypoint
Expand Down
24 changes: 15 additions & 9 deletions svtyper/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
# ==================================================
# VCF parsing tools
# ==================================================
def confidence_interval(var, tag, alt_tag, max_ci_dist):
    """Return a variant's confidence interval as a two-element list of ints.

    The interval is read from ``var.info[tag]``, a comma-separated
    ``"lower,upper"`` string. When its width (``upper - lower``) is strictly
    greater than ``max_ci_dist``, the interval stored under
    ``var.info[alt_tag]`` is parsed and returned instead.

    Args:
        var: object exposing an ``info`` mapping of INFO tag -> string value.
        tag: INFO key of the primary interval (callers pass e.g. 'CIPOS').
        alt_tag: INFO key of the substitute interval (e.g. 'CIPOS95').
        max_ci_dist: widest primary interval accepted before switching.

    Returns:
        ``[lower, upper]`` as a list of ints.

    Note:
        ``alt_tag`` is only looked up when the primary interval is too wide;
        a missing key propagates whatever lookup error ``var.info`` raises.
    """
    # Build real lists rather than using map(): on Python 3 map() returns an
    # iterator, which cannot be indexed below and is not a list for callers.
    ci = [int(bound) for bound in var.info[tag].split(',')]
    if ci[1] - ci[0] > max_ci_dist:
        return [int(bound) for bound in var.info[alt_tag].split(',')]
    return ci


class Vcf(object):
def __init__(self):
Expand Down Expand Up @@ -120,7 +126,7 @@ def sample_to_col(self, sample):
def _init_bnd_breakpoint_func():
bnd_cache = {}

def _get_bnd_breakpoints(variant):
def _get_bnd_breakpoints(variant, max_ci_dist):
if variant.info['MATEID'] in bnd_cache:
var2 = variant
var = bnd_cache[variant.info['MATEID']]
Expand All @@ -129,8 +135,8 @@ def _get_bnd_breakpoints(variant):
posA = var.pos
posB = var2.pos
# confidence intervals
ciA = map(int, var.info['CIPOS'].split(','))
ciB = map(int, var2.info['CIPOS'].split(','))
ciA = confidence_interval(var, 'CIPOS', 'CIPOS95', max_ci_dist)
ciB = confidence_interval(var2, 'CIPOS', 'CIPOS95', max_ci_dist)

# infer the strands from the alt allele
if var.alt[-1] == '[' or var.alt[-1] == ']':
Expand Down Expand Up @@ -163,14 +169,14 @@ def _get_bnd_breakpoints(variant):
return _get_bnd_breakpoints

@staticmethod
def _default_get_breakpoints(variant):
def _default_get_breakpoints(variant, max_ci_dist):
chromA = variant.chrom
chromB = variant.chrom
posA = variant.pos
posB = int(variant.get_info('END'))
# confidence intervals
ciA = map(int, variant.info['CIPOS'].split(','))
ciB = map(int, variant.info['CIEND'].split(','))
ciA = confidence_interval(variant, 'CIPOS', 'CIPOS95', max_ci_dist)
ciB = confidence_interval(variant, 'CIEND', 'CIEND95', max_ci_dist)
svtype = variant.get_svtype()
if svtype == 'DEL':
var_length = posB - posA
Expand Down Expand Up @@ -202,17 +208,17 @@ def _default_get_breakpoints(variant):

return breakpoints

def get_variant_breakpoints(self, variant):
def get_variant_breakpoints(self, variant, max_ci_dist):
if self._bnd_breakpoint_func is None:
func = self._init_bnd_breakpoint_func()
self._bnd_breakpoint_func = func

breakpoints = None
if variant.get_svtype() == 'BND':
func = self._bnd_breakpoint_func
breakpoints = func(variant)
breakpoints = func(variant, max_ci_dist)
else:
breakpoints = self._default_get_breakpoints(variant)
breakpoints = self._default_get_breakpoints(variant, max_ci_dist)

return breakpoints

Expand Down
21 changes: 12 additions & 9 deletions svtyper/singlesample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cytoolz.itertoolz import partition_all

import svtyper.version
from svtyper.parsers import Vcf, Variant, Sample, SamFragment
from svtyper.parsers import Vcf, Variant, Sample, SamFragment, confidence_interval
from svtyper.utils import die, logit, prob_mapq, write_sample_json, tempdir, vcf_headers, vcf_variants, vcf_samples
from svtyper.statistics import bayes_gt

Expand All @@ -28,6 +28,7 @@ def get_args():
parser.add_argument('-n', dest='num_samp', metavar='INT', type=int, required=False, default=1000000, help='number of reads to sample from BAM file for building insert size distribution [1000000]')
parser.add_argument('-q', '--sum_quals', action='store_true', required=False, help='add genotyping quality to existing QUAL (default: overwrite QUAL field)')
parser.add_argument('--max_reads', metavar='INT', type=int, default=1000, required=False, help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: 1000)')
parser.add_argument('--max_ci_dist', metavar='INT', type=int, default=1e10, required=False, help='maximum size of a confidence interval before 95% CI is used intead (default: 1e10)')
parser.add_argument('--split_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for split reads [1]')
parser.add_argument('--disc_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for discordant paired-end reads [1]')
parser.add_argument('--debug', action='store_true', help=argparse.SUPPRESS)
Expand Down Expand Up @@ -123,14 +124,14 @@ def init_vcf(vcffile, sample, scratchdir):
v.add_sample(sample.name)
return v

def collect_breakpoints(vcf):
def collect_breakpoints(vcf, max_ci_dist):
breakpoints = []
for vline in vcf_variants(vcf.filename):
v = vline.rstrip().split('\t')
variant = Variant(v, vcf)
if not variant.has_svtype(): continue
if not variant.is_valid_svtype(): continue
brkpts = vcf.get_variant_breakpoints(variant)
brkpts = vcf.get_variant_breakpoints(variant, max_ci_dist)
if brkpts is None: continue
breakpoints.append(brkpts)
return breakpoints
Expand Down Expand Up @@ -573,7 +574,7 @@ def assign_genotype_to_variant(variant, sample, genotype_result):
variant.genotype(sample.name).set_format('AB', outcome['formats']['AB'])
return variant

def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug):
def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, max_ci_dist, debug):
# initializations
bnd_cache = {}
src_vcf.write_header(out_vcf)
Expand Down Expand Up @@ -607,7 +608,7 @@ def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_qu
variant.write(out_vcf)
continue

breakpoints = src_vcf.get_variant_breakpoints(variant)
breakpoints = src_vcf.get_variant_breakpoints(variant, max_ci_dist)

# special BND processing
if variant.get_svtype() == 'BND':
Expand Down Expand Up @@ -706,15 +707,15 @@ def apply_genotypes_to_vcf(src_vcf, out_vcf, genotypes, sample, sum_quals):
variant2.genotype = variant.genotype
variant2.write(out_vcf)

def genotype_parallel(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug, cores, breakpoint_batch_size, ref_fasta):
def genotype_parallel(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, max_ci_dist, debug, cores, breakpoint_batch_size, ref_fasta):

# cleanup unused library attributes
for rg in sample.rg_to_lib:
sample.rg_to_lib[rg].cleanup()

# 1st pass through input vcf -- collect all the relevant breakpoints
logit("Collecting breakpoints")
breakpoints = collect_breakpoints(src_vcf)
breakpoints = collect_breakpoints(src_vcf, max_ci_dist)
logit("Number of breakpoints/SVs to process: {}".format(len(breakpoints)))
logit("Collecting regions")
regions = [ get_breakpoint_regions(b, sample, z) for b in breakpoints ]
Expand Down Expand Up @@ -772,6 +773,7 @@ def sso_genotype(bam_string,
ref_fasta,
sum_quals,
max_reads,
max_ci_dist,
cores,
batch_size):

Expand Down Expand Up @@ -802,11 +804,11 @@ def sso_genotype(bam_string,
if cores is None:
logit("Genotyping Input VCF (Serial Mode)")
# pass through input vcf -- perform actual genotyping
genotype_serial(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug)
genotype_serial(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, max_ci_dist, debug)
else:
logit("Genotyping Input VCF (Parallel Mode)")

genotype_parallel(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug, cores, batch_size, ref_fasta)
genotype_parallel(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, max_ci_dist, debug, cores, batch_size, ref_fasta)


sample.close()
Expand Down Expand Up @@ -834,6 +836,7 @@ def main():
args.ref_fasta,
args.sum_quals,
args.max_reads,
args.max_ci_dist,
args.cores,
args.batch_size)

Expand Down
2 changes: 1 addition & 1 deletion svtyper/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__author__ = "Colby Chiang (colbychiang@wustl.edu)"
__version__ = "v0.6.0"
__version__ = "v0.6.2"
2 changes: 2 additions & 0 deletions tests/test_singlesample.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_serial_integration(self):
ref_fasta=None,
sum_quals=False,
max_reads=1000,
max_ci_dist=1e10,
cores=None,
batch_size=1000)

Expand All @@ -56,6 +57,7 @@ def test_parallel_integration(self):
ref_fasta=None,
sum_quals=False,
max_reads=1000,
max_ci_dist=1e10,
cores=1,
batch_size=1000)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_svtyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def test_integration(self):
alignment_outpath=None,
ref_fasta=None,
sum_quals=False,
max_reads=None)
max_reads=None,
max_ci_dist=1e10)

fail_msg = "did not file output vcf '{}' after running sv_genotype".format(out_vcf)
self.assertTrue(os.path.exists(out_vcf), fail_msg)
Expand Down

0 comments on commit 070e78b

Please sign in to comment.