Support unmapped reads in Spark. (#3369)
Fixes #2571 and #2572
tomwhite authored Sep 7, 2017
1 parent 28c3b7d commit b06340f
Showing 6 changed files with 137 additions and 22 deletions.
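For orientation, a minimal sketch of what this change enables, mirroring the new unit test further down: reads overlapping a set of intervals plus every unmapped read come back in a single RDD. The TraversalParameters constructor, the ReadsSparkSource constructor, and the getParallelReads overload are all taken from the diff below; the input path and the wrapper class name are hypothetical, and as the javadoc below notes, local files need a fully-qualified file:// URI.

import java.util.List;

import com.google.common.collect.ImmutableList;
import htsjdk.samtools.ValidationStringency;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;

public class UnmappedReadsSketch {
    public static void main(String[] args) {
        final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
        final ReadsSparkSource readsSource = new ReadsSparkSource(ctx, ValidationStringency.SILENT);

        // Ask for reads overlapping 20:10,000,009-10,000,011 *and* all unmapped reads.
        final List<SimpleInterval> intervals =
                ImmutableList.of(new SimpleInterval("20", 10000009, 10000011));
        final TraversalParameters traversal = new TraversalParameters(intervals, true);

        // Passing null instead of traversal would return every read, mapped and unmapped.
        final JavaRDD<GATKRead> reads =
                readsSource.getParallelReads("file:///data/sample.bam", null, traversal);
        System.out.println("reads returned: " + reads.count());
    }
}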
2 changes: 1 addition & 1 deletion build.gradle
@@ -57,7 +57,7 @@ repositories {

final htsjdkVersion = System.getProperty('htsjdk.version','2.11.0-4-g958dc6e-SNAPSHOT')
final sparkVersion = System.getProperty('spark.version', '2.0.2')
-final hadoopBamVersion = System.getProperty('hadoopBam.version','7.8.0')
+final hadoopBamVersion = System.getProperty('hadoopBam.version','7.9.0')
final genomicsdbVersion = System.getProperty('genomicsdb.version','0.6.4-proto-3.0.0-beta-1')
final testNGVersion = '6.11'

@@ -11,6 +11,7 @@
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.*;
+import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.datasources.ReferenceWindowFunctions;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
@@ -221,10 +222,19 @@ public JavaRDD<GATKRead> getReads() {
* @return all reads from our reads input(s) as a {@link JavaRDD}, bounded by intervals if specified, and unfiltered.
*/
public JavaRDD<GATKRead> getUnfilteredReads() {
+TraversalParameters traversalParameters;
+if ( intervalArgumentCollection.intervalsSpecified() ) {
+traversalParameters = intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary());
+} else if ( hasIntervals() ) { // intervals may have been supplied by editIntervals
+traversalParameters = new TraversalParameters(getIntervals(), false);
+} else {
+traversalParameters = null; // no intervals were specified so return all reads (mapped and unmapped)
+}
+
// TODO: This if statement is a temporary hack until #959 gets resolved.
if (readInput.endsWith(".adam")) {
try {
-return readsSource.getADAMReads(readInput, intervals, getHeaderForReads());
+return readsSource.getADAMReads(readInput, traversalParameters, getHeaderForReads());
} catch (IOException e) {
throw new UserException("Failed to read ADAM file " + readInput, e);
}
@@ -234,8 +244,7 @@ public JavaRDD<GATKRead> getUnfilteredReads() {
throw new UserException.MissingReference("A reference file is required when using CRAM files.");
}
final String refPath = hasReference() ? referenceArguments.getReferenceFile().getAbsolutePath() : null;
-// If no intervals were specified (intervals == null), this will return all reads (mapped and unmapped)
-return readsSource.getParallelReads(readInput, refPath, intervals, bamPartitionSplitSize);
+return readsSource.getParallelReads(readInput, refPath, traversalParameters, bamPartitionSplitSize);
}
}
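As a standalone paraphrase of the branch just added to getUnfilteredReads() (illustrative only; the method and parameter names here are invented), the selection boils down to: honour user-specified intervals as-is, wrap tool-injected intervals without unmapped traversal, and fall back to null, which means no restriction at all.

import java.util.List;

import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.utils.SimpleInterval;

final class TraversalChoiceSketch {
    // userTraversal stands in for intervalArgumentCollection.getTraversalParameters(...),
    // toolIntervals for getIntervals() as populated by editIntervals.
    static TraversalParameters choose(final TraversalParameters userTraversal,
                                      final List<SimpleInterval> toolIntervals) {
        if (userTraversal != null) {
            return userTraversal;                                 // user's intervals, possibly including unmapped reads
        }
        if (toolIntervals != null) {
            return new TraversalParameters(toolIntervals, false); // tool-supplied intervals, never unmapped
        }
        return null;                                              // no intervals: all reads, mapped and unmapped
    }
}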

@@ -21,6 +21,7 @@
import org.apache.spark.broadcast.Broadcast;
import org.bdgenomics.formats.avro.AlignmentRecord;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
+import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
@@ -69,24 +70,24 @@ public ReadsSparkSource(final JavaSparkContext ctx, final ValidationStringency v
* i.e., file:///path/to/bam.bam.
* @param readFileName file to load
* @param referencePath Reference path or null if not available. Reference is required for CRAM files.
-* @param intervals intervals of reads to include.
+* @param traversalParameters parameters controlling which reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
* @return RDD of (SAMRecord-backed) GATKReads from the file.
*/
-public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final List<SimpleInterval> intervals) {
-return getParallelReads(readFileName, referencePath, intervals, 0);
+public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final TraversalParameters traversalParameters) {
+return getParallelReads(readFileName, referencePath, traversalParameters, 0);
}

/**
* Loads Reads using Hadoop-BAM. For local files, bam must have the fully-qualified path,
* i.e., file:///path/to/bam.bam.
* @param readFileName file to load
* @param referencePath Reference path or null if not available. Reference is required for CRAM files.
-* @param intervals intervals of reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
+* @param traversalParameters parameters controlling which reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
* @param splitSize maximum bytes of bam file to read into a single partition, increasing this will result in fewer partitions. A value of zero means
* use the default split size (determined by the Hadoop input format, typically the size of one HDFS block).
* @return RDD of (SAMRecord-backed) GATKReads from the file.
*/
-public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final List<SimpleInterval> intervals, final long splitSize) {
+public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final TraversalParameters traversalParameters, final long splitSize) {
SAMFileHeader header = getHeader(readFileName, referencePath);

// use the Hadoop configuration attached to the Spark context to maintain cumulative settings
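Since splitSize is documented just above, a small self-contained sketch of the four-argument overload (the HDFS path and wrapper class are hypothetical), asking for roughly 16 MB input splits instead of the default:

import org.apache.spark.api.java.JavaRDD;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.utils.read.GATKRead;

final class SplitSizeSketch {
    // A splitSize of 0 falls back to the input format's default, typically one HDFS block.
    static JavaRDD<GATKRead> readWithSmallSplits(final ReadsSparkSource source,
                                                 final TraversalParameters traversal) {
        return source.getParallelReads("hdfs:///data/sample.bam", null, traversal, 16L * 1024 * 1024);
    }
}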
@@ -100,19 +101,21 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
setHadoopBAMConfigurationProperties(readFileName, referencePath);

boolean isBam = IOUtils.isBamFileName(readFileName);
-if (isBam && intervals != null && !intervals.isEmpty()) {
-BAMInputFormat.setIntervals(conf, intervals);
-} else {
-conf.unset(BAMInputFormat.INTERVALS_PROPERTY);
+if (isBam) {
+if (traversalParameters == null) {
+BAMInputFormat.unsetTraversalParameters(conf);
+} else {
+BAMInputFormat.setTraversalParameters(conf, traversalParameters.getIntervalsForTraversal(), traversalParameters.traverseUnmappedReads());
+}
}

rdd2 = ctx.newAPIHadoopFile(
-readFileName, AnySAMInputFormat.class, LongWritable.class, SAMRecordWritable.class,
-conf);
+readFileName, AnySAMInputFormat.class, LongWritable.class, SAMRecordWritable.class,
+conf);

JavaRDD<GATKRead> reads= rdd2.map(v1 -> {
SAMRecord sam = v1._2().get();
-if (isBam || samRecordOverlaps(sam, intervals)) { // don't check overlaps for BAM since it is done by input format
+if (isBam || samRecordOverlaps(sam, traversalParameters)) { // don't check overlaps for BAM since it is done by input format
return (GATKRead) SAMRecordToGATKReadAdapter.headerlessReadAdapter(sam);
}
return null;
@@ -149,7 +152,7 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
* @param inputPath path to the Parquet data
* @return RDD of (ADAM-backed) GATKReads from the file.
*/
-public JavaRDD<GATKRead> getADAMReads(final String inputPath, final List<SimpleInterval> intervals, final SAMFileHeader header) throws IOException {
+public JavaRDD<GATKRead> getADAMReads(final String inputPath, final TraversalParameters traversalParameters, final SAMFileHeader header) throws IOException {
Job job = Job.getInstance(ctx.hadoopConfiguration());
AvroParquetInputFormat.setAvroReadSchema(job, AlignmentRecord.getClassSchema());
Broadcast<SAMFileHeader> bHeader;
@@ -163,7 +166,7 @@ public JavaRDD<GATKRead> getADAMReads(final String inputPath, final List<SimpleI
inputPath, AvroParquetInputFormat.class, Void.class, AlignmentRecord.class, job.getConfiguration())
.values();
JavaRDD<GATKRead> readsRdd = recordsRdd.map(record -> new BDGAlignmentRecordToGATKReadAdapter(record, bHeader.getValue()));
-JavaRDD<GATKRead> filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), intervals));
+JavaRDD<GATKRead> filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), traversalParameters));
return putPairsInSamePartition(header, filteredRdd);
}

@@ -287,10 +290,17 @@ private void setHadoopBAMConfigurationProperties(final String inputName, final S
* formats that don't support query-by-interval natively at the Hadoop-BAM layer.
*/
//TODO: use IntervalsSkipList, see https://github.com/broadinstitute/gatk/issues/1531
-private static boolean samRecordOverlaps(final SAMRecord record, final List<SimpleInterval> intervals ) {
-if (intervals == null || intervals.isEmpty()) {
+private static boolean samRecordOverlaps(final SAMRecord record, final TraversalParameters traversalParameters ) {
+if (traversalParameters == null) {
return true;
}
+if (traversalParameters.traverseUnmappedReads() && record.getReadUnmappedFlag() && record.getAlignmentStart() == SAMRecord.NO_ALIGNMENT_START) {
+return true; // include record if unmapped records should be traversed and record is unmapped
+}
+List<SimpleInterval> intervals = traversalParameters.getIntervalsForTraversal();
+if (intervals == null || intervals.isEmpty()) {
+return false; // no intervals means 'no mapped reads'
+}
for (SimpleInterval interval : intervals) {
if (record.getReadUnmappedFlag() && record.getAlignmentStart() != SAMRecord.NO_ALIGNMENT_START) {
// This follows the behavior of htsjdk's SamReader which states that "an unmapped read will be returned
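To make the distinction behind the new samRecordOverlaps logic concrete (this snippet is not part of the change; the header, read names, and coordinates are invented): htsjdk distinguishes a truly unplaced read, which the unmapped-traversal pass above keeps outright, from an unmapped read placed at its mate's coordinate, which instead has to overlap one of the traversal intervals, matching the SamReader behaviour quoted in the comment above.

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceRecord;

final class UnmappedRecordKinds {
    public static void main(String[] args) {
        final SAMFileHeader header = new SAMFileHeader();
        header.addSequence(new SAMSequenceRecord("20", 63025520));

        // Truly unplaced: unmapped flag set, no alignment start at all.
        final SAMRecord unplaced = new SAMRecord(header);
        unplaced.setReadName("unplaced");
        unplaced.setReadUnmappedFlag(true);
        unplaced.setAlignmentStart(SAMRecord.NO_ALIGNMENT_START);

        // Unmapped but placed at its mate's position: unmapped flag set, yet it carries a coordinate,
        // so it is matched against the traversal intervals rather than the unmapped pass.
        final SAMRecord placedUnmapped = new SAMRecord(header);
        placedUnmapped.setReadName("placedUnmapped");
        placedUnmapped.setReadUnmappedFlag(true);
        placedUnmapped.setReferenceName("20");
        placedUnmapped.setAlignmentStart(10000010);

        System.out.println(unplaced.getAlignmentStart());       // 0, i.e. SAMRecord.NO_ALIGNMENT_START
        System.out.println(placedUnmapped.getAlignmentStart()); // 10000010
    }
}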
@@ -13,6 +13,7 @@
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.hellbender.cmdline.programgroups.TestSparkProgramGroup;
+import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
@@ -55,7 +56,13 @@ protected void runTool(final JavaSparkContext ctx) {
JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));

ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
-JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
+TraversalParameters traversalParameters;
+if ( hasIntervals() ) {
+traversalParameters = intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary());
+} else {
+traversalParameters = null;
+}
+JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, traversalParameters, bamPartitionSplitSize), input2);

// Start by verifying that we have same number of reads and duplicates in each BAM.
long firstBamSize = firstReads.count();
@@ -9,6 +9,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
+import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
@@ -219,7 +220,8 @@ public void testIntervals() throws IOException {
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
List<SimpleInterval> intervals =
ImmutableList.of(new SimpleInterval("17", 69010, 69040), new SimpleInterval("17", 69910, 69920));
-JavaRDD<GATKRead> reads = readSource.getParallelReads(NA12878_chr17_1k_BAM, null, intervals);
+TraversalParameters traversalParameters = new TraversalParameters(intervals, false);
+JavaRDD<GATKRead> reads = readSource.getParallelReads(NA12878_chr17_1k_BAM, null, traversalParameters);

SamReaderFactory samReaderFactory = SamReaderFactory.makeDefault().validationStringency(ValidationStringency.SILENT);
try (SamReader samReader = samReaderFactory.open(new File(NA12878_chr17_1k_BAM))) {
@@ -229,6 +231,27 @@ public void testIntervals() throws IOException {
}
}

+@Test(groups = "spark")
+public void testIntervalsWithUnmapped() throws IOException {
+String bam = publicTestDir + "org/broadinstitute/hellbender/engine/CEUTrio.HiSeq.WGS.b37.NA12878.snippet_with_unmapped.bam";
+JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
+ReadsSparkSource readSource = new ReadsSparkSource(ctx);
+List<SimpleInterval> intervals = ImmutableList.of(new SimpleInterval("20", 10000009, 10000011));
+TraversalParameters traversalParameters = new TraversalParameters(intervals, true);
+JavaRDD<GATKRead> reads = readSource.getParallelReads(bam, null, traversalParameters);
+
+SamReaderFactory samReaderFactory = SamReaderFactory.makeDefault().validationStringency(ValidationStringency.SILENT);
+try (SamReader samReader = samReaderFactory.open(new File(bam))) {
+int seqIndex = samReader.getFileHeader().getSequenceIndex("20");
+SAMRecordIterator query = samReader.query(new QueryInterval[]{new QueryInterval(seqIndex, 10000009, 10000011)}, false);
+int queryReads = Iterators.size(query);
+query.close();
+SAMRecordIterator queryUnmapped = samReader.queryUnmapped();
+queryReads += Iterators.size(queryUnmapped);
+Assert.assertEquals(reads.count(), queryReads);
+}
+}

/**
* Loads Reads using samReaderFactory, then calling ctx.parallelize.
* @param bam file to load
