Support unmapped reads in Spark.
tomwhite committed Aug 15, 2017
1 parent d4f2fce commit 8e90c46
Showing 5 changed files with 135 additions and 21 deletions.
@@ -11,6 +11,7 @@
import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.cmdline.argumentcollections.*;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.datasources.ReferenceWindowFunctions;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
@@ -221,10 +222,18 @@ public JavaRDD<GATKRead> getReads() {
* @return all reads from our reads input(s) as a {@link JavaRDD}, bounded by intervals if specified, and unfiltered.
*/
public JavaRDD<GATKRead> getUnfilteredReads() {
// If no intervals were specified this will return all reads (mapped and unmapped)
TraversalParameters traversalParameters;
if ( hasIntervals() ) {
traversalParameters = intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary());
} else {
traversalParameters = null;
}

// TODO: This if statement is a temporary hack until #959 gets resolved.
if (readInput.endsWith(".adam")) {
try {
return readsSource.getADAMReads(readInput, intervals, getHeaderForReads());
return readsSource.getADAMReads(readInput, traversalParameters, getHeaderForReads());
} catch (IOException e) {
throw new UserException("Failed to read ADAM file " + readInput, e);
}
@@ -234,8 +243,7 @@ public JavaRDD<GATKRead> getUnfilteredReads() {
throw new UserException.MissingReference("A reference file is required when using CRAM files.");
}
final String refPath = hasReference() ? referenceArguments.getReferenceFile().getAbsolutePath() : null;
// If no intervals were specified (intervals == null), this will return all reads (mapped and unmapped)
return readsSource.getParallelReads(readInput, refPath, intervals, bamPartitionSplitSize);
return readsSource.getParallelReads(readInput, refPath, traversalParameters, bamPartitionSplitSize);
}
}

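For orientation, here is a minimal, self-contained sketch (not part of this commit) of how the new TraversalParameters-based overload might be called. The TraversalParameters constructor and the getParallelReads signature mirror the diff and the tests further down; the input path is a placeholder, and SparkContextFactory.getTestSparkContext() is used only to keep the example runnable.

import htsjdk.samtools.ValidationStringency;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.Arrays;
import java.util.List;

public class TraversalParametersUsageSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
        ReadsSparkSource readsSource = new ReadsSparkSource(ctx, ValidationStringency.SILENT);

        // One mapped interval on chromosome 20, plus unmapped reads (second argument = true).
        List<SimpleInterval> intervals = Arrays.asList(new SimpleInterval("20", 10000009, 10000011));
        TraversalParameters traversal = new TraversalParameters(intervals, true);

        // Placeholder input path; null reference path (only needed for CRAM); 0 = default split size.
        JavaRDD<GATKRead> reads =
                readsSource.getParallelReads("file:///path/to/reads.bam", null, traversal, 0);

        // Passing null traversal parameters keeps the old behaviour:
        // every read, mapped and unmapped, is returned.
        JavaRDD<GATKRead> allReads =
                readsSource.getParallelReads("file:///path/to/reads.bam", null, null, 0);

        System.out.println("selected: " + reads.count() + ", total: " + allReads.count());
    }
}
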
@@ -21,6 +21,7 @@
import org.apache.spark.broadcast.Broadcast;
import org.bdgenomics.formats.avro.AlignmentRecord;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
@@ -69,24 +70,24 @@ public ReadsSparkSource(final JavaSparkContext ctx, final ValidationStringency v
* i.e., file:///path/to/bam.bam.
* @param readFileName file to load
* @param referencePath Reference path or null if not available. Reference is required for CRAM files.
* @param intervals intervals of reads to include.
* @param traversalParameters parameters controlling which reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
* @return RDD of (SAMRecord-backed) GATKReads from the file.
*/
public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final List<SimpleInterval> intervals) {
return getParallelReads(readFileName, referencePath, intervals, 0);
public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final TraversalParameters traversalParameters) {
return getParallelReads(readFileName, referencePath, traversalParameters, 0);
}

/**
* Loads Reads using Hadoop-BAM. For local files, bam must have the fully-qualified path,
* i.e., file:///path/to/bam.bam.
* @param readFileName file to load
* @param referencePath Reference path or null if not available. Reference is required for CRAM files.
* @param intervals intervals of reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
* @param traversalParameters parameters controlling which reads to include. If <code>null</code> then all the reads (both mapped and unmapped) will be returned.
* @param splitSize maximum bytes of bam file to read into a single partition, increasing this will result in fewer partitions. A value of zero means
* use the default split size (determined by the Hadoop input format, typically the size of one HDFS block).
* @return RDD of (SAMRecord-backed) GATKReads from the file.
*/
public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final List<SimpleInterval> intervals, final long splitSize) {
public JavaRDD<GATKRead> getParallelReads(final String readFileName, final String referencePath, final TraversalParameters traversalParameters, final long splitSize) {
SAMFileHeader header = getHeader(readFileName, referencePath);

// use the Hadoop configuration attached to the Spark context to maintain cumulative settings
Expand All @@ -100,19 +101,21 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
setHadoopBAMConfigurationProperties(readFileName, referencePath);

boolean isBam = IOUtils.isBamFileName(readFileName);
if (isBam && intervals != null && !intervals.isEmpty()) {
BAMInputFormat.setIntervals(conf, intervals);
} else {
conf.unset(BAMInputFormat.INTERVALS_PROPERTY);
if (isBam) {
if (traversalParameters == null) {
BAMInputFormat.unsetTraversalParameters(conf);
} else {
BAMInputFormat.setTraversalParameters(conf, traversalParameters.getIntervalsForTraversal(), traversalParameters.traverseUnmappedReads());
}
}

rdd2 = ctx.newAPIHadoopFile(
readFileName, AnySAMInputFormat.class, LongWritable.class, SAMRecordWritable.class,
conf);
readFileName, AnySAMInputFormat.class, LongWritable.class, SAMRecordWritable.class,
conf);

JavaRDD<GATKRead> reads= rdd2.map(v1 -> {
SAMRecord sam = v1._2().get();
if (isBam || samRecordOverlaps(sam, intervals)) { // don't check overlaps for BAM since it is done by input format
if (isBam || samRecordOverlaps(sam, traversalParameters)) { // don't check overlaps for BAM since it is done by input format
return (GATKRead) SAMRecordToGATKReadAdapter.headerlessReadAdapter(sam);
}
return null;
@@ -149,7 +152,7 @@ public JavaRDD<GATKRead> getParallelReads(final String readFileName, final Strin
* @param inputPath path to the Parquet data
* @return RDD of (ADAM-backed) GATKReads from the file.
*/
public JavaRDD<GATKRead> getADAMReads(final String inputPath, final List<SimpleInterval> intervals, final SAMFileHeader header) throws IOException {
public JavaRDD<GATKRead> getADAMReads(final String inputPath, final TraversalParameters traversalParameters, final SAMFileHeader header) throws IOException {
Job job = Job.getInstance(ctx.hadoopConfiguration());
AvroParquetInputFormat.setAvroReadSchema(job, AlignmentRecord.getClassSchema());
Broadcast<SAMFileHeader> bHeader;
Expand All @@ -163,7 +166,7 @@ public JavaRDD<GATKRead> getADAMReads(final String inputPath, final List<SimpleI
inputPath, AvroParquetInputFormat.class, Void.class, AlignmentRecord.class, job.getConfiguration())
.values();
JavaRDD<GATKRead> readsRdd = recordsRdd.map(record -> new BDGAlignmentRecordToGATKReadAdapter(record, bHeader.getValue()));
JavaRDD<GATKRead> filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), intervals));
JavaRDD<GATKRead> filteredRdd = readsRdd.filter(record -> samRecordOverlaps(record.convertToSAMRecord(header), traversalParameters));
return putPairsInSamePartition(header, filteredRdd);
}

@@ -287,10 +290,17 @@ private void setHadoopBAMConfigurationProperties(final String inputName, final S
* formats that don't support query-by-interval natively at the Hadoop-BAM layer.
*/
//TODO: use IntervalsSkipList, see https://github.com/broadinstitute/gatk/issues/1531
private static boolean samRecordOverlaps(final SAMRecord record, final List<SimpleInterval> intervals ) {
if (intervals == null || intervals.isEmpty()) {
private static boolean samRecordOverlaps(final SAMRecord record, final TraversalParameters traversalParameters ) {
if (traversalParameters == null) {
return true;
}
if (traversalParameters.traverseUnmappedReads() && record.getReadUnmappedFlag() && record.getAlignmentStart() == SAMRecord.NO_ALIGNMENT_START) {
return true; // include record if unmapped records should be traversed and record is unmapped
}
List<SimpleInterval> intervals = traversalParameters.getIntervalsForTraversal();
if (intervals == null || intervals.isEmpty()) {
return false; // no intervals means 'no mapped reads'
}
for (SimpleInterval interval : intervals) {
if (record.getReadUnmappedFlag() && record.getAlignmentStart() != SAMRecord.NO_ALIGNMENT_START) {
// This follows the behavior of htsjdk's SamReader which states that "an unmapped read will be returned
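To summarise the filtering rule that samRecordOverlaps now applies for non-BAM sources (BAM inputs are filtered by the Hadoop-BAM input format itself, as the comment in the diff notes), here is an illustrative standalone restatement. This is a sketch, not the project's code: the branch for placed-but-unmapped reads is assumed to follow the htsjdk SamReader convention referenced in the comment above, since those lines are collapsed in this view.

import htsjdk.samtools.SAMRecord;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.utils.SimpleInterval;

import java.util.List;

final class TraversalFilterSketch {
    /** Returns true if the record should be kept under the given traversal restriction. */
    static boolean keep(final SAMRecord record, final TraversalParameters traversal) {
        if (traversal == null) {
            return true; // no restriction at all: every read, mapped or unmapped, is kept
        }
        final boolean unmappedWithoutPosition = record.getReadUnmappedFlag()
                && record.getAlignmentStart() == SAMRecord.NO_ALIGNMENT_START;
        if (unmappedWithoutPosition) {
            return traversal.traverseUnmappedReads(); // kept only when unmapped traversal was requested
        }
        final List<SimpleInterval> intervals = traversal.getIntervalsForTraversal();
        if (intervals == null || intervals.isEmpty()) {
            return false; // no intervals means "no mapped reads"
        }
        for (final SimpleInterval interval : intervals) {
            if (record.getReadUnmappedFlag()) {
                // Placed-but-unmapped read: assumed kept when its assigned start falls inside
                // the interval, mirroring htsjdk's SamReader query behaviour.
                if (interval.getContig().equals(record.getReferenceName())
                        && interval.getStart() <= record.getAlignmentStart()
                        && record.getAlignmentStart() <= interval.getEnd()) {
                    return true;
                }
            } else if (interval.overlaps(record)) {
                return true; // mapped read overlapping any requested interval is kept
            }
        }
        return false;
    }
}
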
@@ -13,6 +13,7 @@
import org.broadinstitute.barclay.argparser.BetaFeature;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.hellbender.cmdline.programgroups.TestSparkProgramGroup;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource;
import org.broadinstitute.hellbender.exceptions.GATKException;
@@ -55,7 +56,13 @@ protected void runTool(final JavaSparkContext ctx) {
JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));

ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
TraversalParameters traversalParameters;
if ( hasIntervals() ) {
traversalParameters = intervalArgumentCollection.getTraversalParameters(getHeaderForReads().getSequenceDictionary());
} else {
traversalParameters = null;
}
JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, traversalParameters, bamPartitionSplitSize), input2);

// Start by verifying that we have same number of reads and duplicates in each BAM.
long firstBamSize = firstReads.count();
@@ -9,6 +9,7 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.TraversalParameters;
import org.broadinstitute.hellbender.engine.spark.SparkContextFactory;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
@@ -219,7 +220,8 @@ public void testIntervals() throws IOException {
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
List<SimpleInterval> intervals =
ImmutableList.of(new SimpleInterval("17", 69010, 69040), new SimpleInterval("17", 69910, 69920));
JavaRDD<GATKRead> reads = readSource.getParallelReads(NA12878_chr17_1k_BAM, null, intervals);
TraversalParameters traversalParameters = new TraversalParameters(intervals, false);
JavaRDD<GATKRead> reads = readSource.getParallelReads(NA12878_chr17_1k_BAM, null, traversalParameters);

SamReaderFactory samReaderFactory = SamReaderFactory.makeDefault().validationStringency(ValidationStringency.SILENT);
try (SamReader samReader = samReaderFactory.open(new File(NA12878_chr17_1k_BAM))) {
@@ -229,6 +231,27 @@
}
}

@Test(groups = "spark")
public void testIntervalsWithUnmapped() throws IOException {
String bam = publicTestDir + "org/broadinstitute/hellbender/engine/CEUTrio.HiSeq.WGS.b37.NA12878.snippet_with_unmapped.bam";
JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
ReadsSparkSource readSource = new ReadsSparkSource(ctx);
List<SimpleInterval> intervals = ImmutableList.of(new SimpleInterval("20", 10000009, 10000011));
TraversalParameters traversalParameters = new TraversalParameters(intervals, true);
JavaRDD<GATKRead> reads = readSource.getParallelReads(bam, null, traversalParameters);

SamReaderFactory samReaderFactory = SamReaderFactory.makeDefault().validationStringency(ValidationStringency.SILENT);
try (SamReader samReader = samReaderFactory.open(new File(bam))) {
int seqIndex = samReader.getFileHeader().getSequenceIndex("20");
SAMRecordIterator query = samReader.query(new QueryInterval[]{new QueryInterval(seqIndex, 10000009, 10000011)}, false);
int queryReads = Iterators.size(query);
query.close();
SAMRecordIterator queryUnmapped = samReader.queryUnmapped();
queryReads += Iterators.size(queryUnmapped);
Assert.assertEquals(reads.count(), queryReads);
}
}

/**
* Loads Reads using samReaderFactory, then calling ctx.parallelize.
* @param bam file to load
@@ -11,10 +11,12 @@
import java.util.List;
import org.broadinstitute.hellbender.CommandLineProgramTest;
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
import org.broadinstitute.hellbender.engine.ReadsDataSource;
import org.broadinstitute.hellbender.engine.filters.ReadLengthReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadNameReadFilter;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.test.ArgumentsBuilder;
import org.broadinstitute.hellbender.utils.test.BaseTest;
import org.broadinstitute.hellbender.utils.test.IntegrationTestSpec;
@@ -23,6 +25,8 @@
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.*;

public final class PrintReadsSparkIntegrationTest extends CommandLineProgramTest {

private static final File TEST_DATA_DIR = getTestDataDir();
@@ -251,6 +255,68 @@ public void testNonExistentReference() throws Exception {
runCommandLine(args.getArgsArray());
}

@DataProvider(name = "UnmappedReadInclusionTestData")
public Object[][] unmappedReadInclusionTestData() {
// This bam has mapped reads from various contigs, plus a few unmapped reads with no mapped mate
final File unmappedBam = new File(publicTestDir + "org/broadinstitute/hellbender/engine/reads_data_source_test1_with_unmapped.bam");

// This is a snippet of the CEUTrio.HiSeq.WGS.b37.NA12878 bam from large, with mapped reads
// from chromosome 20 (with one mapped read having an unmapped mate), plus several unmapped
// reads with no mapped mate.
final File ceuSnippet = new File(publicTestDir + "org/broadinstitute/hellbender/engine/CEUTrio.HiSeq.WGS.b37.NA12878.snippet_with_unmapped.bam");
final File ceuSnippetCram = new File(publicTestDir + "org/broadinstitute/hellbender/engine/CEUTrio.HiSeq.WGS.b37.NA12878.snippet_with_unmapped.cram");

return new Object[][] {
{ unmappedBam, null, Arrays.asList("unmapped"), Arrays.asList("u1", "u2", "u3", "u4", "u5") },
// The same interval as above in an intervals file
{ unmappedBam, null, Arrays.asList(publicTestDir + "org/broadinstitute/hellbender/engine/reads_data_source_test1_unmapped.intervals"), Arrays.asList("u1", "u2", "u3", "u4", "u5") },
{ unmappedBam, null, Arrays.asList("1:200-300", "unmapped"), Arrays.asList("a", "b", "c", "u1", "u2", "u3", "u4", "u5") },
{ unmappedBam, null, Arrays.asList("1:200-300", "4:700-701", "unmapped"), Arrays.asList("a", "b", "c", "k", "u1", "u2", "u3", "u4", "u5") },
// The same intervals as above in an intervals file
{ unmappedBam, null, Arrays.asList(publicTestDir + "org/broadinstitute/hellbender/engine/reads_data_source_test1_unmapped2.intervals"), Arrays.asList("a", "b", "c", "k", "u1", "u2", "u3", "u4", "u5") },
{ ceuSnippet, null, Arrays.asList("unmapped"), Arrays.asList("g", "h", "h", "i", "i") },
{ ceuSnippet, null, Arrays.asList("20:10000009-10000011"), Arrays.asList("a", "b", "c", "d", "e") },
{ ceuSnippet, null, Arrays.asList("20:10000009-10000011", "unmapped"), Arrays.asList("a", "b", "c", "d", "e", "g", "h", "h", "i", "i") },
{ ceuSnippet, null, Arrays.asList("20:10000009-10000013", "unmapped"), Arrays.asList("a", "b", "c", "d", "e", "f", "f", "g", "h", "h", "i", "i") },
{ ceuSnippetCram, b37_reference_20_21, Arrays.asList("unmapped"), Arrays.asList("g", "h", "h", "i", "i") },
{ ceuSnippetCram, b37_reference_20_21, Arrays.asList("20:10000009-10000011", "unmapped"), Arrays.asList("a", "b", "c", "d", "e", "g", "h", "h", "i", "i") },
{ ceuSnippetCram, b37_reference_20_21, Arrays.asList("20:10000009-10000013", "unmapped"), Arrays.asList("a", "b", "c", "d", "e", "f", "f", "g", "h", "h", "i", "i") }
};
}

@Test(dataProvider = "UnmappedReadInclusionTestData")
public void testUnmappedReadInclusion( final File input, final String reference, final List<String> intervalStrings, final List<String> expectedReadNames ) {
final File outFile = createTempFile("testUnmappedReadInclusion", ".bam");

final ArgumentsBuilder args = new ArgumentsBuilder();
args.add("-I"); args.add(input.getAbsolutePath());
args.add("-O"); args.add(outFile.getAbsolutePath());
for ( final String intervalString : intervalStrings ) {
args.add("-L"); args.add(intervalString);
}
if ( reference != null ) {
args.add("-R"); args.add(reference);
}

runCommandLine(args);

try ( final ReadsDataSource outputReadsSource = new ReadsDataSource(outFile.toPath()) ) {
final List<GATKRead> actualReads = new ArrayList<>();
for ( final GATKRead read : outputReadsSource ) {
actualReads.add(read);
}

if (actualReads.size() != expectedReadNames.size()) {
System.out.println("actual: " + actualReads);
System.out.println("expectedReadNames: " + expectedReadNames);
}
Assert.assertEquals(actualReads.size(), expectedReadNames.size(), "Wrong number of reads output");

for ( int readNumber = 0; readNumber < actualReads.size(); ++readNumber ) {
Assert.assertEquals(actualReads.get(readNumber).getName(), expectedReadNames.get(readNumber), "Unexpected read name");
}
}
}

@DataProvider(name="readFilterTestData")
public Object[][] testReadFilterData() {