Skip to content

Commit

Permalink
responded to david benjamins comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesemery committed Sep 17, 2019
1 parent 6c4abdd commit 7c91e9a
Show file tree
Hide file tree
Showing 12 changed files with 457 additions and 486 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ public GraphBasedKBestHaplotypeFinder(final BaseGraph<V, E> graph, final Set<V>
super(sinks, sources, graph);
}

/**
 * Returns a graph that is safe to traverse: the input graph itself when the cycle detector finds no cycle,
 * otherwise the result of {@code removeCyclesAndVerticesThatDontLeadToSinks} for the given sources and sinks.
 *
 * @param graph   the graph to inspect.
 * @param sources the source vertices handed on to the cycle-removal routine.
 * @param sinks   the sink vertices handed on to the cycle-removal routine.
 * @return {@code graph} unchanged when it has no cycles, the cycle-free result otherwise.
 */
@Override
protected BaseGraph<V, E> removeCyclesIfNecessary(BaseGraph<V, E> graph, Set<V> sources, Set<V> sinks) {
    final boolean hasCycles = new CycleDetector<>(graph).detectCycles();
    if (!hasCycles) {
        // Nothing to repair: hand back the very same instance.
        return graph;
    }
    return removeCyclesAndVerticesThatDontLeadToSinks(graph, sources, sinks);
}

/**
* Constructor for the special case of a single source and sink
*/
Expand All @@ -49,6 +44,11 @@ public GraphBasedKBestHaplotypeFinder(final BaseGraph<V, E> graph) {
this(graph, graph.getSources(), graph.getSinks());
}

/**
 * Cycle-handling flag for this finder; this implementation always answers {@code false}.
 * NOTE(review): presumably the superclass consults this flag to decide whether cycles must be stripped
 * from the graph before the search runs - confirm against the parent class.
 *
 * @return always {@code false}.
 */
@Override
public boolean keepCycles() {
return false;
}

/**
* Implement Dijkstra's algorithm as described in https://en.wikipedia.org/wiki/K_shortest_path_routing
*/
Expand Down Expand Up @@ -84,71 +84,4 @@ public List<KBestHaplotype<V, E>> findBestHaplotypes(final int maxNumberOfHaplot
return result;
}

/**
 * Produces a copy of the input graph from which the cycle-closing edges, and the vertices that cannot
 * reach any sink, have been removed. The input graph itself is left untouched.
 *
 * @param original the graph to clean up.
 * @param sources  vertices from which the search for offending edges and vertices is started.
 * @param sinks    vertices considered valid path ends.
 * @return never {@code null}.
 */
protected BaseGraph<V, E> removeCyclesAndVerticesThatDontLeadToSinks(final BaseGraph<V, E> original, final Collection<V> sources, final Set<V> sinks) {
    final Set<E> cycleEdges = new HashSet<>(original.edgeSet().size());
    final Set<V> deadVertices = new HashSet<>(original.vertexSet().size());

    // Every source is searched with a fresh parent set; the edges/vertices to drop accumulate across sources.
    boolean anySourceReachesSink = false;
    for (final V start : sources) {
        final Set<V> parents = new HashSet<>(original.vertexSet().size());
        if (findGuiltyVerticesAndEdgesToRemoveCycles(original, start, sinks, cycleEdges, deadVertices, parents)) {
            anySourceReachesSink = true;
        }
    }

    Utils.validate(anySourceReachesSink, () -> "could not find any path from the source vertex to the sink vertex after removing cycles: "
            + Arrays.toString(sources.toArray()) + " => " + Arrays.toString(sinks.toArray()));

    // At least one edge or vertex must have been flagged, otherwise there is nothing we can do about the cycles.
    Utils.validate(!cycleEdges.isEmpty() || !deadVertices.isEmpty(), "cannot find a way to remove the cycles");

    final BaseGraph<V, E> pruned = original.clone();
    pruned.removeAllEdges(cycleEdges);
    pruned.removeAllVertices(deadVertices);
    return pruned;
}

/**
 * Recursive depth-first search that looks for edges and vertices that need to be removed to get rid of cycles.
 * <p>
 * An edge is flagged when it points back to a vertex that is on the path currently being explored;
 * a vertex is flagged when none of its explored children reaches a sink.
 *
 * @param graph the original graph.
 * @param currentVertex current search vertex.
 * @param sinks considered sink vertices.
 * @param edgesToRemove collection of edges that need to be removed in order to get rid of cycles.
 * @param verticesToRemove collection of vertices that can be removed.
 * @param parentVertices collection of vertices that preceded the {@code currentVertex}; i.e. it can be
 *                       reached from those vertices using edges existing in {@code graph}. The set is
 *                       restored to its incoming content before this method returns.
 *
 * @return {@code true} to indicate that some sink vertex is reachable from {@code currentVertex},
 *  {@code false} otherwise.
 */
private boolean findGuiltyVerticesAndEdgesToRemoveCycles(final BaseGraph<V, E> graph,
                                                         final V currentVertex,
                                                         final Set<V> sinks,
                                                         final Set<E> edgesToRemove,
                                                         final Set<V> verticesToRemove,
                                                         final Set<V> parentVertices) {
    if (sinks.contains(currentVertex)) {
        return true;
    }

    final Set<E> outgoingEdges = graph.outgoingEdgesOf(currentVertex);
    // The vertex is an ancestor only while its own descendants are being explored.
    parentVertices.add(currentVertex);

    boolean reachesSink = false;
    for (final E edge : outgoingEdges) {
        final V child = graph.getEdgeTarget(edge);
        if (parentVertices.contains(child)) {
            // Back edge to a vertex on the current path: this edge closes a cycle.
            edgesToRemove.add(edge);
        } else {
            // Always recurse (no short-circuit) so that every branch gets inspected for cycles and dead ends.
            final boolean childReachSink = findGuiltyVerticesAndEdgesToRemoveCycles(graph, child, sinks, edgesToRemove, verticesToRemove, parentVertices);
            reachesSink = reachesSink || childReachSink;
        }
    }

    // Leave the current path again. Without this the set degrades into "every vertex visited so far", and an
    // edge that merely rejoins an already explored branch (e.g. the second arm of a diamond) would be
    // reported as a cycle edge and its origin wrongly discarded as a dead end.
    // NOTE(review): shared sub-graphs are now re-explored once per incoming path; confirm this is acceptable
    // for the graph sizes handled here.
    parentVertices.remove(currentVertex);

    if (!reachesSink) {
        verticesToRemove.add(currentVertex);
    }
    return reachesSink;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,55 @@
* data off of the list as well as incrementing all of the trees in the list to point at the next element based on the chosen path.
*/
public class JTBestHaplotype<V extends BaseVertex, E extends BaseEdge> extends KBestHaplotype<V, E> {
private JunctionTreeManager treesInQueue; // An object for storing and managing operations on the queue of junction trees active for this path
private JunctionTreeManager junctionTreeManager; // An object for storing and managing operations on the queue of junction trees active for this path
private int decisionEdgesTakenSinceLastJunctionTreeEvidence;

// NOTE, this constructor is used by JunctionTreeKBestHaplotypeFinder, in both cases paths are chosen by non-junction tree paths
public JTBestHaplotype(final JTBestHaplotype<V, E> previousPath, final List<E> edgesToExtend, final double edgePenalty) {
super(previousPath, edgesToExtend, edgePenalty);
treesInQueue = new JunctionTreeManager(previousPath.treesInQueue);
decisionEdgesTakenSinceLastJunctionTreeEvidence = treesInQueue.hasJunctionTreeEvidence() ? 0 : previousPath.decisionEdgesTakenSinceLastJunctionTreeEvidence;
junctionTreeManager = new JunctionTreeManager(previousPath.junctionTreeManager);
decisionEdgesTakenSinceLastJunctionTreeEvidence = junctionTreeManager.hasJunctionTreeEvidence() ? 0 : previousPath.decisionEdgesTakenSinceLastJunctionTreeEvidence;
}

// Constructor to be used for internal calls from {@link #getApplicableNextEdgesBasedOnJunctionTrees()}
private JTBestHaplotype(final JTBestHaplotype<V, E> previousPath, final List<E> chain, final int edgeMultiplicity, final int totalOutgoingMultiplicity, final boolean thisPathBasedOnJT) {
super(previousPath, chain, computeLogPenaltyScore( edgeMultiplicity, totalOutgoingMultiplicity));
treesInQueue = new JunctionTreeManager(previousPath.treesInQueue);
treesInQueue.traverseEdgeForAllTrees(chain.get(chain.size() - 1));
junctionTreeManager = new JunctionTreeManager(previousPath.junctionTreeManager);
junctionTreeManager.traverseEdgeForAllTrees(chain.get(chain.size() - 1));
// I'm aware that the chain is only an estimate of the proper length, especially if we got here due to being under the weight threshold for a given tree... the chain length is a heuristic as it is...
decisionEdgesTakenSinceLastJunctionTreeEvidence = thisPathBasedOnJT ? 0 : previousPath.decisionEdgesTakenSinceLastJunctionTreeEvidence + 1;
}

// JTBestHaplotype constructor for constructing an entirely new haplotype builder.
public JTBestHaplotype(final V initialVertex, final BaseGraph<V,E> graph) {
super(initialVertex, graph);
treesInQueue = new JunctionTreeManager();
junctionTreeManager = new JunctionTreeManager();
decisionEdgesTakenSinceLastJunctionTreeEvidence = 0;
}

public boolean hasJunctionTreeEvidence() {
return treesInQueue.hasJunctionTreeEvidence();
return junctionTreeManager.hasJunctionTreeEvidence();
}

//TODO this needs to be the same logic as the method below, this is temporary
// returns true if there is a symbolic edge pointing to the reference end or if there is insufficient node data
public boolean hasStoppingEvidence(final int weightThreshold) {
int currentActiveNodeIndex = 0;
JunctionTreeLinkedDeBruinGraph.ThreadingNode oldestTree = !treesInQueue.hasJunctionTreeEvidence() ? null : treesInQueue.get(currentActiveNodeIndex);
int totalOut = getTotalOutForBranch(oldestTree);

// Keep removing trees until we find one under our threshold TODO this should be in a helper method
while (oldestTree != null && totalOut < weightThreshold) {
// We remove old branches from the tree only if they no longer have any evidence, otherwise we look at younger branches
if (totalOut <= 0) {
treesInQueue.removeEldestTree();
} else { // Otherwise look at the next tree in the list
currentActiveNodeIndex++;

// Traverse the non-empty trees until we find one with evidence over our threshold. If we ever find a symbolic end vertex then we stop.
for (JunctionTreeLinkedDeBruinGraph.ThreadingNode tree : junctionTreeManager.removeEmptyNodesAndReturnIterator()) {
int totalOut = getTotalOutForBranch(tree);

// Are any of these vertexes symbolic stops?
if (tree.getChildrenNodes().values().stream()
.anyMatch(JunctionTreeLinkedDeBruinGraph.ThreadingNode::isSymbolicEnd)) {
return true;
}
if ( totalOut >= weightThreshold) {
return false;
}
oldestTree = currentActiveNodeIndex >= treesInQueue.size() ? null : treesInQueue.get(currentActiveNodeIndex);
totalOut = getTotalOutForBranch(oldestTree);
}

return oldestTree == null || oldestTree.getChildrenNodes().values().stream()
.anyMatch(JunctionTreeLinkedDeBruinGraph.ThreadingNode::isSymbolicEnd);
// None of our junction trees cover the stop vertex, close it
return true;
}

// Tally the total outgoing weight for a particular branch
Expand All @@ -92,22 +90,18 @@ private static int getTotalOutForBranch(final JunctionTreeLinkedDeBruinGraph.Thr
//TODO for reviewer - is this the best way to structure this? I'm not sure how to decide about end nodes based on this, passing them back seems wrong
@SuppressWarnings({"unchecked"})
public List<JTBestHaplotype<V, E>> getApplicableNextEdgesBasedOnJunctionTrees(final List<E> chain, final Set<E> outgoingEdges, final int weightThreshold) {
Set<MultiSampleEdge> edgesAccountedForByJunctionTrees = new HashSet<>(); // Since we check multiple junction trees for paths, make sure we are getting
Set<MultiSampleEdge> edgesAccountedForByJunctionTrees = new HashSet<>(); // Since we check multiple junction trees for paths, keep track of which paths we have taken to avoid adding duplicate paths to the graph
List<JTBestHaplotype<V, E>> output = new ArrayList<>();
int currentActiveNodeIndex = 0;
JunctionTreeLinkedDeBruinGraph.ThreadingNode oldestTree = !treesInQueue.hasJunctionTreeEvidence() ? null : treesInQueue.get(currentActiveNodeIndex);
while (oldestTree != null) {
int totalOut = getTotalOutForBranch(oldestTree);
for ( JunctionTreeLinkedDeBruinGraph.ThreadingNode tree : junctionTreeManager.removeEmptyNodesAndReturnIterator()) {
int totalOut = getTotalOutForBranch(tree);

// If the total evidence emerging from a given branch

//TODO add SOME sanity check to ensure that the vertex we stand on and the edges we are polling line up
for (Map.Entry<MultiSampleEdge, JunctionTreeLinkedDeBruinGraph.ThreadingNode> childNode : oldestTree.getChildrenNodes().entrySet()) {
for (Map.Entry<MultiSampleEdge, JunctionTreeLinkedDeBruinGraph.ThreadingNode> childNode : tree.getChildrenNodes().entrySet()) {
if (!outgoingEdges.contains(childNode.getKey())) {
throw new GATKException("While constructing graph, there was an incongruity between a JunctionTree edge and the edge present on graph traversal");
}

// Don't add edges to the symbolic end vertex here at all, thats handled elsewhere, also don't add the same edge again if we pulled it in from a younger tree.
// Don't add edges to the symbolic end vertex here at all, that's handled by {@link #hasStoppingEvidence()}, also don't add the same edge again if we pulled it in from a younger tree.
if (!childNode.getValue().isSymbolicEnd() && // ignore symbolic end branches, those are handled elsewhere
!edgesAccountedForByJunctionTrees.contains(childNode.getKey())) {
edgesAccountedForByJunctionTrees.add(childNode.getKey());
Expand All @@ -118,17 +112,10 @@ public List<JTBestHaplotype<V, E>> getApplicableNextEdgesBasedOnJunctionTrees(fi
}
}

// If there isn't enough outgoing evidence, then we
if (totalOut < weightThreshold){
// We remove old branches from the tree only if they no longer have any evidence, otherwise we look at younger branches
if (totalOut <= 0) {
treesInQueue.removeEldestTree();
} else { // Otherwise look at the next tree in the list
currentActiveNodeIndex++;
}
oldestTree = currentActiveNodeIndex >= treesInQueue.size() ? null : treesInQueue.get(currentActiveNodeIndex);
} else {
// We know that the eldest tree had enough weight to ignore younger trees
// If there isn't enough outgoing evidence, then we poll the next oldest tree for evidence
// This is done to alleviate the problem that the oldest junction tree may have little evidence and drop connectivity
// information better represented by one of the younger trees in the path.
if (totalOut >= weightThreshold) {
return output;
}
}
Expand Down Expand Up @@ -175,7 +162,7 @@ public int getDecisionEdgesTakenSinceLastJunctionTreeEvidence() {
* @param junctionTreeForNode Junction tree to add
*/
public void addJunctionTree(final JunctionTreeLinkedDeBruinGraph.ThreadingTree junctionTreeForNode) {
if (treesInQueue.addJunctionTree(junctionTreeForNode)) {
if (junctionTreeManager.addJunctionTree(junctionTreeForNode)) {
decisionEdgesTakenSinceLastJunctionTreeEvidence = 0;
}
}
Expand Down Expand Up @@ -214,12 +201,17 @@ public boolean addJunctionTree(final JunctionTreeLinkedDeBruinGraph.ThreadingTre

// method to handle incrementing all of the nodes in the tree simultaneously
public void traverseEdgeForAllTrees(E edgeTaken) {
activeNodes = activeNodes.stream().map(node -> {
if (!node.getChildrenNodes().containsKey(edgeTaken)) {
return null;
}
return node.getChildrenNodes().get(edgeTaken);
}).filter(Objects::nonNull).filter(node -> !node.hasNoEvidence()).collect(Collectors.toList());
activeNodes = activeNodes.stream()
.filter(node -> node.getChildrenNodes().containsKey(edgeTaken))
.map(node -> node.getChildrenNodes().get(edgeTaken))
.filter(node -> !node.hasNoEvidence())
.collect(Collectors.toList());
}

// Prunes the active nodes down to those whose total outgoing weight (per getTotalOutForBranch) is positive,
// stores the pruned list back into activeNodes and returns it for iteration.
public Iterable<JunctionTreeLinkedDeBruinGraph.ThreadingNode> removeEmptyNodesAndReturnIterator() {
    final List<JunctionTreeLinkedDeBruinGraph.ThreadingNode> nodesWithEvidence = new ArrayList<>();
    for (final JunctionTreeLinkedDeBruinGraph.ThreadingNode candidate : activeNodes) {
        if (getTotalOutForBranch(candidate) > 0) {
            nodesWithEvidence.add(candidate);
        }
    }
    // Replace (rather than mutate) the list, exactly as the stream-based collection did.
    activeNodes = nodesWithEvidence;
    return activeNodes;
}

private JunctionTreeLinkedDeBruinGraph.ThreadingNode get(int i) {
Expand All @@ -230,7 +222,7 @@ private int size() {
return activeNodes == null ? 0 : activeNodes.size();
}

private void removeEldestTree() {
private void removeOldestTree() {
activeNodes.remove(0);
}

Expand Down
Loading

0 comments on commit 7c91e9a

Please sign in to comment.