Use overlaps partitioner in HaplotypeCaller
Broadcast HCEngine

Make GenotypingEngine and IndexedSet Kryo-serializable
tomwhite committed Jun 21, 2017
1 parent 3ec7399 commit 9df50e7
Showing 5 changed files with 136 additions and 29 deletions.
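The core of the change is that the HaplotypeCallerEngine is now built once on the driver and broadcast to the executors, instead of being re-instantiated inside every mapPartitions call; that is also why GenotypingEngine and IndexedSet had to become Kryo-serializable. The sketch below illustrates that broadcast-once, reuse-per-partition pattern in isolation. ExpensiveEngine and the local Spark setup are illustrative stand-ins, not GATK code.

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public final class BroadcastEngineSketch {

    // Stand-in for an expensive-to-build, serializable engine such as HaplotypeCallerEngine.
    static final class ExpensiveEngine implements Serializable {
        private static final long serialVersionUID = 1L;
        String call(final String datum) { return "called:" + datum; }
    }

    public static void main(final String[] args) {
        final SparkConf conf = new SparkConf().setAppName("broadcast-engine").setMaster("local[2]");
        try (final JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Build once on the driver, broadcast to the executors.
            final Broadcast<ExpensiveEngine> engineBroadcast = ctx.broadcast(new ExpensiveEngine());

            final JavaRDD<String> data = ctx.parallelize(Arrays.asList("a", "b", "c"), 2);
            final JavaRDD<String> results = data.mapPartitions(records -> {
                // Reuse the broadcast engine for the whole partition instead of rebuilding it.
                final ExpensiveEngine engine = engineBroadcast.value();
                final List<String> out = new ArrayList<>();
                while (records.hasNext()) {
                    out.add(engine.call(records.next()));
                }
                return out.iterator();
            });
            results.collect().forEach(System.out::println);
        }
    }
}

In the commit itself the broadcast value is a fully configured HaplotypeCallerEngine, so everything it transitively holds (GenotypingEngine, IndexedSet, and so on) must serialize with Kryo, which motivates the other parts of the diff.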
@@ -0,0 +1,58 @@
package org.broadinstitute.hellbender.tools;


import java.io.Serializable;
import java.util.Iterator;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsamplingIterator;
import org.broadinstitute.hellbender.utils.read.GATKRead;

/**
* A simple shard implementation intended to be used for splitting reads by partition in Spark tools
*/
public final class DownsampleableSparkReadShard implements Shard<GATKRead>, Serializable {
private static final long serialVersionUID = 1L;

private final ShardBoundary boundaries;
private final Iterable<GATKRead> reads;
private ReadsDownsampler downsampler;

public DownsampleableSparkReadShard(final ShardBoundary boundaries, final Iterable<GATKRead> reads){
this.boundaries = Utils.nonNull(boundaries);
this.reads = Utils.nonNull(reads);
}

/**
* Reads in this shard will be downsampled using this downsampler before being returned.
*
* @param downsampler downsampler to use (may be null, which signifies that no downsampling is to be performed)
*/
public void setDownsampler(final ReadsDownsampler downsampler) {
this.downsampler = downsampler;
}

@Override
public SimpleInterval getInterval() {
return boundaries.getInterval();
}

@Override
public SimpleInterval getPaddedInterval() {
return boundaries.getPaddedInterval();
}

@Override
public Iterator<GATKRead> iterator() {
Iterator<GATKRead> readsIterator = reads.iterator();

if ( downsampler != null ) {
readsIterator = new ReadsDownsamplingIterator(readsIterator, downsampler);
}

return readsIterator;
}
}
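As a usage sketch (not part of the commit): the hypothetical helper below wraps an existing Shard<GATKRead> the same way the HaplotypeCallerSpark change further down does, so that reads are positionally downsampled as they are iterated. The method and class names are invented; the types and constructors are the ones visible in this diff.

import htsjdk.samtools.SAMFileHeader;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.tools.DownsampleableSparkReadShard;
import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler;
import org.broadinstitute.hellbender.utils.read.GATKRead;

final class ShardDownsamplingSketch {
    /**
     * Wrap a shard so that at most maxReadsPerAlignmentStart reads per alignment start
     * are returned when the shard is iterated; a non-positive value disables downsampling.
     */
    static Shard<GATKRead> downsample(final Shard<GATKRead> shard, final int maxReadsPerAlignmentStart, final SAMFileHeader header) {
        final ShardBoundary boundary = new ShardBoundary(shard.getInterval(), shard.getPaddedInterval());
        final DownsampleableSparkReadShard wrapped = new DownsampleableSparkReadShard(boundary, shard);
        if (maxReadsPerAlignmentStart > 0) {
            wrapped.setDownsampler(new PositionalDownsampler(maxReadsPerAlignmentStart, header));
        }
        return wrapped;
    }
}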
@@ -27,6 +27,7 @@
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.SparkReadShard;
import org.broadinstitute.hellbender.engine.spark.SparkSharder;
import org.broadinstitute.hellbender.engine.spark.datasources.VariantsSparkSink;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
@@ -37,6 +38,8 @@
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.downsampling.PositionalDownsampler;
import org.broadinstitute.hellbender.utils.downsampling.ReadsDownsampler;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;
import scala.Tuple2;
@@ -88,6 +91,9 @@ public static class ShardingArgumentCollection implements Serializable {
@Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
public int assemblyRegionPadding = HaplotypeCaller.DEFAULT_ASSEMBLY_REGION_PADDING;

@Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
public int maxReadsPerAlignmentStart = HaplotypeCaller.DEFAULT_MAX_READS_PER_ALIGNMENT;

@Advanced
@Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
public double activeProbThreshold = HaplotypeCaller.DEFAULT_ACTIVE_PROB_THRESHOLD;
@@ -115,12 +121,12 @@ public boolean requiresReference(){
@Override
protected void runTool(final JavaSparkContext ctx) {
final List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(getHeaderForReads().getSequenceDictionary());
final JavaRDD<VariantContext> variants = callVariantsWithHaplotypeCaller(getAuthHolder(), ctx, getReads(), getHeaderForReads(), getReference(), intervals, hcArgs, shardingArgs);
final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, getHeaderForReads(), new ReferenceMultiSourceAdapter(getReference(), getAuthHolder()));
final JavaRDD<VariantContext> variants = callVariantsWithHaplotypeCaller(getAuthHolder(), ctx, getReads(), getHeaderForReads(), getReference(), intervals, hcArgs, shardingArgs, hcEngine);
if (hcArgs.emitReferenceConfidence == ReferenceConfidenceMode.GVCF) {
// VariantsSparkSink/Hadoop-BAM VCFOutputFormat do not support writing GVCF, see https://github.com/broadinstitute/gatk/issues/2738
writeVariants(variants);
writeVariants(variants, hcEngine);
} else {
final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, getHeaderForReads(), new ReferenceMultiSourceAdapter(getReference(), getAuthHolder()));
variants.cache(); // without caching, computations are run twice as a side effect of finding partition boundaries for sorting
try {
VariantsSparkSink.writeVariants(ctx, output, variants, hcEngine.makeVCFHeader(getHeaderForReads().getSequenceDictionary(), new HashSet<>()));
@@ -148,6 +154,7 @@ public List<ReadFilter> getDefaultReadFilters() {
* @param intervals the intervals to restrict calling to
* @param hcArgs haplotype caller arguments
* @param shardingArgs arguments to control how the assembly regions are sharded
* @param hcEngine the {@link HaplotypeCallerEngine} to broadcast and reuse across partitions
* @return an RDD of Variants
*/
public static JavaRDD<VariantContext> callVariantsWithHaplotypeCaller(
@@ -158,7 +165,8 @@ public static JavaRDD<VariantContext> callVariantsWithHaplotypeCaller(
final ReferenceMultiSource reference,
final List<SimpleInterval> intervals,
final HaplotypeCallerArgumentCollection hcArgs,
final ShardingArgumentCollection shardingArgs) {
final ShardingArgumentCollection shardingArgs,
final HaplotypeCallerEngine hcEngine) {
Utils.validateArg(hcArgs.dbsnp.dbsnp == null, "HaplotypeCallerSpark does not yet support -D or --dbsnp arguments" );
Utils.validateArg(hcArgs.comps.isEmpty(), "HaplotypeCallerSpark does not yet support -comp or --comp arguments" );
Utils.validateArg(hcArgs.bamOutputPath == null, "HaplotypeCallerSpark does not yet support -bamout or --bamOutput");
@@ -168,15 +176,19 @@ public static JavaRDD<VariantContext> callVariantsWithHaplotypeCaller(

final Broadcast<ReferenceMultiSource> referenceBroadcast = ctx.broadcast(reference);
final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast = ctx.broadcast(hcArgs);
final OverlapDetector<ShardBoundary> overlaps = getShardBoundaryOverlapDetector(header, intervals, shardingArgs.readShardSize, shardingArgs.readShardPadding);
final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast = ctx.broadcast(overlaps);

final JavaRDD<Shard<GATKRead>> readShards = createReadShards(shardBoundariesBroadcast, reads);
final Broadcast<HaplotypeCallerEngine> hcEngineBroadcast = ctx.broadcast(hcEngine);

final List<ShardBoundary> shardBoundaries = getShardBoundaries(header, intervals, shardingArgs.readShardSize, shardingArgs.readShardPadding);

final int maxLocatableSize = reads.map(r -> r.getEnd() - r.getStart() + 1).reduce(Math::max);

final JavaRDD<Shard<GATKRead>> readShards = createReadShards(ctx, shardBoundaries, reads, header, maxLocatableSize);

final JavaRDD<Tuple2<AssemblyRegion, SimpleInterval>> assemblyRegions = readShards
.mapPartitions(shardsToAssemblyRegions(authHolder, referenceBroadcast, hcArgsBroadcast, shardingArgs, header));
.mapPartitions(shardsToAssemblyRegions(authHolder, referenceBroadcast, hcArgsBroadcast, shardingArgs, header, hcEngineBroadcast));

return assemblyRegions.mapPartitions(callVariantsFromAssemblyRegions(authHolder, header, referenceBroadcast, hcArgsBroadcast));
return assemblyRegions.mapPartitions(callVariantsFromAssemblyRegions(authHolder, header, referenceBroadcast, hcArgsBroadcast, hcEngineBroadcast));
}

/**
@@ -188,12 +200,10 @@ private static FlatMapFunction<Iterator<Tuple2<AssemblyRegion, SimpleInterval>>,
final AuthHolder authHolder,
final SAMFileHeader header,
final Broadcast<ReferenceMultiSource> referenceBroadcast,
final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast) {
final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast,
final Broadcast<HaplotypeCallerEngine> hcEngineBroadcast) {
return regionAndIntervals -> {
//HaplotypeCallerEngine isn't serializable but is expensive to instantiate, so construct and reuse one for every partition
final ReferenceMultiSourceAdapter referenceReader = new ReferenceMultiSourceAdapter(referenceBroadcast.getValue(), authHolder);
final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), header, referenceReader);
return iteratorToStream(regionAndIntervals).flatMap(regionToVariants(hcEngine)).iterator();
return iteratorToStream(regionAndIntervals).flatMap(regionToVariants(hcEngineBroadcast.getValue())).iterator();
};
}

@@ -216,15 +226,14 @@ private static Function<Tuple2<AssemblyRegion, SimpleInterval>, Stream<? extends
*
* This will be replaced by a parallel writer similar to what's done with {@link org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink}
*/
private void writeVariants(JavaRDD<VariantContext> variants) {
private void writeVariants(JavaRDD<VariantContext> variants, HaplotypeCallerEngine hcEngine) {
final List<VariantContext> collectedVariants = variants.collect();
final SAMSequenceDictionary referenceDictionary = getReferenceSequenceDictionary();

final List<VariantContext> sortedVariants = collectedVariants.stream()
.sorted((o1, o2) -> IntervalUtils.compareLocatables(o1, o2, referenceDictionary))
.collect(Collectors.toList());

final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, getHeaderForReads(), new ReferenceMultiSourceAdapter(getReference(), getAuthHolder()));
try(final VariantContextWriter writer = hcEngine.makeVCFWriter(output, getBestAvailableSequenceDictionary())) {
hcEngine.writeHeader(writer, getHeaderForReads().getSequenceDictionary(), new HashSet<>());
sortedVariants.forEach(writer::add);
@@ -233,17 +242,12 @@ private void writeVariants(JavaRDD<VariantContext> variants) {

/**
* Create an RDD of {@link Shard} from an RDD of {@link GATKRead}
* @param shardBoundariesBroadcast broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
* @param shardBoundaries the shard boundaries for creating shards from
* @param reads RDD of {@link GATKRead}
* @return an RDD of reads grouped into potentially overlapping shards
*/
private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast, final JavaRDD<GATKRead> reads) {
final JavaPairRDD<ShardBoundary, GATKRead> paired = reads.flatMapToPair(read -> {
final Collection<ShardBoundary> overlappingShards = shardBoundariesBroadcast.value().getOverlaps(read);
return overlappingShards.stream().map(key -> new Tuple2<>(key, read)).iterator();
});
final JavaPairRDD<ShardBoundary, Iterable<GATKRead>> shardsWithReads = paired.groupByKey();
return shardsWithReads.map(shard -> new SparkReadShard(shard._1(), shard._2()));
private static JavaRDD<Shard<GATKRead>> createReadShards(final JavaSparkContext ctx, final List<ShardBoundary> shardBoundaries, final JavaRDD<GATKRead> reads, final SAMFileHeader header, int maxLocatableSize) {
return SparkSharder.shard(ctx, reads, GATKRead.class, header.getSequenceDictionary(), shardBoundaries, maxLocatableSize);
}

/**
@@ -258,6 +262,17 @@ private static OverlapDetector<ShardBoundary> getShardBoundaryOverlapDetector(fi
return shardBoundaryOverlapDetector;
}

/**
* @return a list of {@link ShardBoundary} based on the -L intervals
*/
private static List<ShardBoundary> getShardBoundaries(final SAMFileHeader header, final List<SimpleInterval> intervals, final int readShardSize, final int readShardPadding) {
return intervals.stream()
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, header.getSequenceDictionary()).stream())
.collect(Collectors.toList());
}

/**
* @return an RDD of {@link Tuple2<AssemblyRegion, SimpleInterval>} which pairs each AssemblyRegion with the
* interval it was generated in
@@ -267,13 +282,24 @@ private static FlatMapFunction<Iterator<Shard<GATKRead>>, Tuple2<AssemblyRegion,
final Broadcast<ReferenceMultiSource> reference,
final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast,
final ShardingArgumentCollection assemblyArgs,
final SAMFileHeader header) {
final SAMFileHeader header,
final Broadcast<HaplotypeCallerEngine> hcEngineBroadcast) {
return shards -> {
final ReferenceMultiSource referenceMultiSource = reference.value();
final ReferenceMultiSourceAdapter referenceSource = new ReferenceMultiSourceAdapter(referenceMultiSource, authHolder);
final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), header, referenceSource);

return iteratorToStream(shards).flatMap(shardToRegion(assemblyArgs, header, referenceSource, hcEngine)).iterator();
final HaplotypeCallerEngine hcEngine = hcEngineBroadcast.getValue();

ReadsDownsampler readsDownsampler = assemblyArgs.maxReadsPerAlignmentStart > 0 ?
new PositionalDownsampler(assemblyArgs.maxReadsPerAlignmentStart, header) : null;
return iteratorToStream(shards)
.map(shard -> {
DownsampleableSparkReadShard downsampledShard = new DownsampleableSparkReadShard(new ShardBoundary(shard.getInterval(), shard.getPaddedInterval()), shard);
downsampledShard.setDownsampler(readsDownsampler);
return downsampledShard;
})
.flatMap(shardToRegion(assemblyArgs, header, referenceSource, hcEngine)).iterator();
};
}

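Taken together, the sharding change in this file replaces the old OverlapDetector + flatMapToPair + groupByKey path with SparkSharder.shard, as the commit title says. Condensed into one hypothetical helper (the calls are the ones from the diff; only the method around them is invented), the new flow is:

import java.util.List;

import htsjdk.samtools.SAMFileHeader;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.Shard;
import org.broadinstitute.hellbender.engine.ShardBoundary;
import org.broadinstitute.hellbender.engine.spark.SparkSharder;
import org.broadinstitute.hellbender.utils.read.GATKRead;

final class OverlapsPartitionerSketch {
    static JavaRDD<Shard<GATKRead>> shardReads(final JavaSparkContext ctx, final JavaRDD<GATKRead> reads, final SAMFileHeader header, final List<ShardBoundary> shardBoundaries) {
        // The largest read length bounds how far a read can extend past its start position,
        // which is what the sharder is given instead of grouping every read under each shard key.
        final int maxLocatableSize = reads.map(r -> r.getEnd() - r.getStart() + 1).reduce(Math::max);
        return SparkSharder.shard(ctx, reads, GATKRead.class, header.getSequenceDictionary(), shardBoundaries, maxLocatableSize);
    }
}

The downsampling wrapper (DownsampleableSparkReadShard plus PositionalDownsampler) is then applied per shard inside shardsToAssemblyRegions, as shown above.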
@@ -21,6 +21,7 @@
import org.broadinstitute.hellbender.utils.variant.GATKVCFHeaderLines;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -38,7 +39,13 @@ public abstract class GenotypingEngine<Config extends StandardCallerArgumentColl

protected VariantAnnotatorEngine annotationEngine;

protected Logger logger;
protected transient Logger logger;

private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
in.defaultReadObject();
logger = LogManager.getLogger(this.getClass()); // Logger is not serializable (even by Kryo)
}

protected final int numberOfGenomes;

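The GenotypingEngine change marks the logger transient and restores it in readObject, since Logger instances are not serializable. A minimal stand-alone illustration of that idiom, assuming the log4j 2 Logger/LogManager classes used in the diff (the class itself is invented):

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class TransientLoggerExample implements Serializable {
    private static final long serialVersionUID = 1L;

    // Excluded from serialization: loggers hold non-serializable state.
    private transient Logger logger = LogManager.getLogger(TransientLoggerExample.class);

    // Field initializers are not re-run on deserialization, so rebuild the logger here.
    private void readObject(final ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        logger = LogManager.getLogger(this.getClass());
    }

    public void doWork() {
        logger.info("working");
    }
}

Marking the field transient keeps both Java serialization and Kryo's field serializer from attempting to serialize it; the readObject hook re-creates it on the Java-serialization path.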
@@ -32,6 +32,14 @@ public final class IndexedSet<E> extends AbstractSet<E> {
*/
private final Map<E, Integer> indexByElement;

/**
* Creates an empty indexed set.
*/
public IndexedSet() {
elements = new ArrayList<>();
indexByElement = new LinkedHashMap<>();
}

/**
* Creates an empty indexed set indicating the expected number of elements.
*
@@ -14,6 +14,7 @@
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeCalculationArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.collections.IndexedSet;
import org.broadinstitute.hellbender.utils.test.BaseTest;
import org.broadinstitute.hellbender.utils.test.SparkTestUtils;
import org.testng.Assert;
@@ -210,4 +211,11 @@ public void testAllelesAreSerializable() {
SparkTestUtils.roundTripInKryo(a, a.getClass(), SparkContextFactory.getTestSparkContext().getConf());
SparkTestUtils.roundTripInKryo(Allele.NO_CALL, Allele.class, SparkContextFactory.getTestSparkContext().getConf());
}

@Test
public void testIndexedSetIsSerializable() {
IndexedSet<String> set = new IndexedSet<>();
set.add("a");
SparkTestUtils.roundTripInKryo(set, IndexedSet.class, SparkContextFactory.getTestSparkContext().getConf());
}
}
