
Added KernelSegmenter. #3590

Merged
merged 8 commits into from
Sep 26, 2017

Conversation

samuelklee (Contributor):

Some of the posts in #2858 might be helpful in reviewing this PR.

samuelklee (Contributor Author):

@davidbenjamin Feel free to comment if you like, but @mbabadi will handle the main review. Thanks guys!

davidbenjamin (Contributor) left a comment:

A few comments but I'm reading for my own benefit, with the review as a side effect.

final List<Integer> windowSizes,
final double numChangepointsPenaltyLinearFactor,
final double numChangepointsPenaltyLogLinearFactor,
final boolean sortByIndex) {
davidbenjamin (Contributor):

Consider implementing this as a nested enum:

public enum ChangepointSortOrder {
  SORT_BY_INDEX, SORT_BY_BACKWARD_SELECTION
}

samuelklee (Contributor Author):

Done.
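For readers following the thread, the enum would replace the boolean flag along these lines (a sketch; everything except the `ChangepointSortOrder` constants quoted above is illustrative, not taken from the PR):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class SortOrderExample {
    public enum ChangepointSortOrder {
        SORT_BY_INDEX, SORT_BY_BACKWARD_SELECTION
    }

    // an enum reads unambiguously at the call site, unlike a bare boolean sortByIndex
    public static List<Integer> sortChangepoints(final List<Integer> changepoints,
                                                 final ChangepointSortOrder sortOrder) {
        final List<Integer> result = new ArrayList<>(changepoints);
        if (sortOrder == ChangepointSortOrder.SORT_BY_INDEX) {
            Collections.sort(result);
        }
        // SORT_BY_BACKWARD_SELECTION keeps the order in which changepoints were selected
        return result;
    }
}
```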

ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
Utils.validateArg(new HashSet<>(windowSizes).size() == windowSizes.size(), "Window sizes must all be unique.");
davidbenjamin (Contributor):

windowSizes.stream().distinct().count() == windowSizes.size() is slightly wordier but more direct.

samuelklee (Contributor Author):

Done.
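The two uniqueness checks are equivalent; a minimal illustration in plain Java (without the GATK `Utils`/`ParamUtils` helpers):

```java
import java.util.HashSet;
import java.util.List;

public class WindowSizeChecks {
    // set-based form from the original code: duplicates collapse, shrinking the set
    public static boolean allUniqueViaSet(final List<Integer> windowSizes) {
        return new HashSet<>(windowSizes).size() == windowSizes.size();
    }

    // stream-based form suggested above: count the distinct elements directly
    public static boolean allUniqueViaDistinct(final List<Integer> windowSizes) {
        return windowSizes.stream().distinct().count() == windowSizes.size();
    }
}
```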

final double[] invSqrtSingularValues = Arrays.stream(svd.getSingularValues()).map(Math::sqrt).map(x -> 1. / (x + EPSILON)).toArray();
@Override
public double visit(int i, int j, double value) {
double sum = 0.;
davidbenjamin (Contributor):

return new IndexRange(0, numSubsample).sum(k -> kernel.apply(data.get(i), dataSubsample.get(k)) * svd.getU().getEntry(k, j) * invSqrtSingularValues[j])

samuelklee (Contributor Author):

Done.

//for N x p matrix Z_ij, returns the N-dimensional vector sum(Z_ij * Z_ij, j = 0,..., p - 1),
//which are the diagonal elements K_ii of the approximate kernel matrix
private static double[] calculateKernelApproximationDiagonal(final RealMatrix reducedObservationMatrix) {
return IntStream.range(0, reducedObservationMatrix.getRowDimension()).boxed()
davidbenjamin (Contributor):

It's lighter and shorter to use IndexRange:

return new IndexRange(0, reducedObservationMatrix.getRowDimension())
   .mapToDouble(i -> Arrays.stream(reducedObservationMatrix.getRow(i)).map(z -> z * z).sum());

You could go a step further with

return new IndexRange(0, reducedObservationMatrix.getRowDimension())
   .mapToDouble(i -> MathUtils.square(reducedObservationMatrix.getRowVector(i).getNorm()));

samuelklee (Contributor Author):

Done.
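The identity stated in the comment above the method can be checked directly: for an N x p reduced observation matrix Z, the diagonal of the approximate kernel matrix K = Z * Z^T is the vector of row sums of squares. A self-contained sketch using plain arrays in place of `RealMatrix`:

```java
public class KernelDiagonalExample {
    // for an N x p reduced observation matrix Z, K_ii of the approximate
    // kernel matrix K = Z * Z^T is the sum of squares of row i
    public static double[] kernelApproximationDiagonal(final double[][] Z) {
        final double[] diagonal = new double[Z.length];
        for (int i = 0; i < Z.length; i++) {
            double sumOfSquares = 0.;
            for (final double z : Z[i]) {
                sumOfSquares += z * z;
            }
            diagonal[i] = sumOfSquares;
        }
        return diagonal;
    }
}
```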

//calculate penalties as a function of the number of changepoints
final int numData = reducedObservationMatrix.getRowDimension();
final List<Double> changepointPenalties = IntStream.range(0, maxNumChangepoints + 1).boxed()
.map(numChangepoints -> numChangepointsPenaltyLinearFactor * numChangepoints
davidbenjamin (Contributor):

It's a one-liner, but I nonetheless think it's worth extracting a method for the changepoint penalty.

samuelklee (Contributor Author):

Agreed, done.
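From the class javadoc quoted later in the review, the penalty for C changepoints with N data points is A * C + B * C * log(N / C); the extracted one-liner might look like this (a sketch — parameter names and the guard for C = 0 are assumptions, not the PR's exact code):

```java
public class PenaltyExample {
    // penalty A * C + B * C * log(N / C), taken as zero when there are no changepoints
    public static double changepointPenalty(final int numChangepoints,
                                            final int numData,
                                            final double linearFactor,
                                            final double logLinearFactor) {
        if (numChangepoints == 0) {
            return 0.;
        }
        return linearFactor * numChangepoints
                + logLinearFactor * numChangepoints * Math.log((double) numData / numChangepoints);
    }
}
```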

.collect(Collectors.toList());

//construct initial list of all segments and initialize costs
final List<Integer> candidateStarts = changepointCandidates.stream().sorted().distinct()
davidbenjamin (Contributor):

How about combine starts and ends into a List<IndexRange> candidate segments, which could be created from a single stream?

samuelklee (Contributor Author):

Segments as I've defined them are end-inclusive, but IndexRange is supposed to represent an end-exclusive interval, so I'll leave things as they are.

final List<Double> costsForMergedSegmentPairs = IntStream.range(0, numSegments - 1).boxed()
.map(i -> new Segment(candidateStarts.get(i), candidateEnds.get(i + 1), reducedObservationMatrix, kernelApproximationDiagonal).cost)
.collect(Collectors.toList()); //cost of each adjacent pair when considered as a single segment
final List<Double> costsForMergingSegmentPairs = IntStream.range(0, numSegments - 1).boxed()
davidbenjamin (Contributor):

If you use double[] above this reduces to MathArrays.ebeSubtract.

samuelklee (Contributor Author):

Good point. However, I'm removing elements from the lists, so I'd rather use an ArrayList (since I like to use arrays to represent fixed-length quantities).

//initialize quantities for recurrence
double D = kernelApproximationDiagonal[start];
final double[] W = Arrays.copyOf(reducedObservationMatrix.getRow(start), p);
double V = Arrays.stream(W).map(w -> w * w).sum();
davidbenjamin (Contributor):

Might be worth extracting a MathUtils for sum of squares of a double[].

samuelklee (Contributor Author):

Where are the boy scouts of yesteryear...

//update quantities in left segment
leftD -= kernelApproximationDiagonal[start];
ZdotW = 0.;
for (int j = 0; j < p; j++) {
davidbenjamin (Contributor):

Can you extract a method for this four-line motif?

samuelklee (Contributor Author):

I thought about it, but decided against it because such a method needs to either return ZdotW and modify the W costs as a side effect or return a pair of quantities. There are also enough variations on the motif to require parameterization of the method. But it does shorten up the code a bit, so done!
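For context, the extracted helper can be reconstructed from its later call sites (signature and body below are inferred, not verbatim from the PR): it returns the dot product of row tau of Z with the running row-sum W, using W's value before the update, and then adds or subtracts that row from W in place — which is exactly why it must either mutate W as a side effect or return a pair.

```java
public class ZdotWExample {
    // returns Z[tau] . W (with W's pre-update value), then updates W += sign * Z[tau] in place;
    // the pre-update dot product is what the recurrence V += 2 * ZdotW + K_tautau requires
    public static double calculateZdotWAndModifyW(final double[][] Z,
                                                  final int p,
                                                  final double[] W,
                                                  final int tau,
                                                  final int signMultiplier) {
        double zDotW = 0.;
        for (int j = 0; j < p; j++) {
            zDotW += Z[tau][j] * W[j];
            W[j] += signMultiplier * Z[tau][j];
        }
        return zDotW;
    }
}
```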

samuelklee (Contributor Author):

Whoops, just realized I broke the tests with some changes in a previous PR. Will fix and address @davidbenjamin's comments, thanks!


codecov-io commented Sep 20, 2017

Codecov Report

Merging #3590 into master will increase coverage by 0.358%.
The diff coverage is 93.939%.

@@               Coverage Diff               @@
##              master     #3590       +/-   ##
===============================================
+ Coverage     79.736%   80.094%   +0.358%     
- Complexity     18148     18799      +651     
===============================================
  Files           1217      1226        +9     
  Lines          66602     69015     +2413     
  Branches       10429     11073      +644     
===============================================
+ Hits           53106     55277     +2171     
- Misses          9289      9415      +126     
- Partials        4207      4323      +116
Impacted Files Coverage Δ Complexity Δ
...umber/utils/optimization/PersistenceOptimizer.java 84.946% <ø> (ø) 27 <0> (ø) ⬇️
...copynumber/utils/segmentation/KernelSegmenter.java 93.939% <93.939%> (ø) 44 <44> (?)
.../tools/spark/sv/evidence/QNamesForKmersFinder.java 83.333% <0%> (-16.667%) 7% <0%> (ø)
...nder/tools/spark/pathseq/PSPathogenTaxonScore.java 78.125% <0%> (-10.11%) 20% <0%> (+17%)
.../copynumber/allelic/alleliccount/AllelicCount.java 65.854% <0%> (-2.003%) 19% <0%> (+6%)
...te/hellbender/tools/spark/sv/utils/SVKmerizer.java 85.714% <0%> (-1.786%) 18% <0%> (ø)
...ools/spark/pathseq/PSFilterArgumentCollection.java 79.487% <0%> (-1.158%) 3% <0%> (+1%)
...ecaller/AssemblyBasedCallerArgumentCollection.java 100% <0%> (ø) 2% <0%> (+1%) ⬆️
...oadinstitute/hellbender/utils/gcs/BucketUtils.java 78.571% <0%> (ø) 39% <0%> (ø) ⬇️
...tools/spark/pathseq/PSScoreArgumentCollection.java 100% <0%> (ø) 1% <0%> (ø) ⬇️
... and 37 more

davidbenjamin (Contributor):

I'm satisfied.

samuelklee (Contributor Author):

@mbabadi Can you review? I need to carry changes from this PR forward in the rest of the branch. Thanks!


mbabadi commented Sep 26, 2017 via email

mbabadi (Contributor) left a comment:

Great PR! The code is well-organized and well-written :)

* This gives the global cost as a function of the number of changepoints <i>C</i>.
* </ol>
* <ol>
* 6) Add a penalty <i>A * C + B * C * log N / C</i> to the global cost and find the minimum to determine the
mbabadi (Contributor):

log N / C -> log(N / C) for clarity.

mbabadi (Contributor):

Also, perhaps add that T is the generic data type, and that this class works for segmenting generic sequential data so long as a kernel (T, T) -> R is provided.

samuelklee (Contributor Author):

Done.
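To make the generic-data point concrete: any kernel (T, T) -> R makes the segmenter applicable to a sequence of T. Two standard kernels on scalar data, expressed as `BiFunction`s (illustrative examples, not code from the PR):

```java
import java.util.function.BiFunction;

public class KernelExamples {
    // linear kernel on scalar data: sensitive to changes in the mean
    public static final BiFunction<Double, Double, Double> LINEAR = (x, y) -> x * y;

    // Gaussian (RBF) kernel with bandwidth sigma: bounded, hence more robust to outliers
    public static BiFunction<Double, Double, Double> gaussian(final double sigma) {
        return (x, y) -> Math.exp(-(x - y) * (x - y) / (2. * sigma * sigma));
    }
}
```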

mbabadi (Contributor):

I guess you forgot to change all instances of T to DATA.


private final List<T> data;

public KernelSegmenter(final List<T> data) {
mbabadi (Contributor):

What about making KernelSegmenter a fully static class? Use a private constructor to prevent instantiation; data is only used in findChangepoints, which could have a public static declaration and take data as an additional argument.

samuelklee (Contributor Author):

I think I was allowing for the possibility of passing the data once and finding multiple sets of changepoints. This could also be done with a static method, I guess...not sure if there are any advantages/disadvantages either way, so I'll leave it as is, if you don't mind!

mbabadi (Contributor):

The advantage of the static implementation would be a small saving in instantiation space/time cost, but that's negligible for our current use cases; so I'm satisfied.

final boolean sortByIndex) {
ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
mbabadi (Contributor):

Also assert !windowSizes.isEmpty().

samuelklee (Contributor Author):

Done.


//perform SVD of kernel matrix of subsampled data
logger.info(String.format("Performing SVD of kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
final SingularValueDecomposition svd = new SingularValueDecomposition(subKernelMatrix);
mbabadi (Contributor):

I believe the SVD implementation in ojAlgo is substantially faster than Apache's. Also, you may want to consider using sgesvd from nd4j (http://nd4j.org/doc/org/nd4j/linalg/api/blas/Lapack.html) for more speed, if SVD is a computational bottleneck in your algorithm. The problem is that you are using Apache matrices, so this solution requires adaptors between RealMatrix and INDArray. By the way, the garbage-generation issues of nd4j have been solved in nd4j 0.9.0 using workspaces (see https://deeplearning4j.org/workspaces).

samuelklee (Contributor Author):

Actually, I think the bottleneck turned out to be the original non-matrix-multiplication implementation of this method.

I'll keep using Apache's SVD for now, but might consider implementing randomized SVD and using that here in the future.
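As background on why approximate factorizations can be much cheaper here: the kernel matrix is symmetric positive semi-definite, so its SVD coincides with its eigendecomposition, and iterative methods based on repeated matrix-vector products recover the dominant part of the spectrum cheaply. Randomized SVD elaborates on this idea with a random block of probe vectors; the simplest member of the family, plain power iteration for the leading eigenpair, is sketched below (an illustration of the principle, not the algorithm the author is proposing to implement):

```java
import java.util.Arrays;

public class PowerIterationExample {
    // power iteration for the leading eigenvalue of a symmetric PSD matrix K;
    // since v stays unit-length, ||K v|| converges to the top eigenvalue
    public static double leadingEigenvalue(final double[][] K, final int numIterations) {
        final int n = K.length;
        final double[] v = new double[n];
        Arrays.fill(v, 1. / Math.sqrt(n));   // generic starting vector
        double eigenvalueEstimate = 0.;
        for (int iteration = 0; iteration < numIterations; iteration++) {
            final double[] Kv = new double[n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    Kv[i] += K[i][j] * v[j];
                }
            }
            double norm = 0.;
            for (final double x : Kv) {
                norm += x * x;
            }
            norm = Math.sqrt(norm);
            eigenvalueEstimate = norm;
            for (int i = 0; i < n; i++) {       // renormalize for the next iteration
                v[i] = Kv[i] / norm;
            }
        }
        return eigenvalueEstimate;
    }
}
```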

for (final int tauPrime : indices) {
D += kernelApproximationDiagonal[tauPrime];
final double ZdotW = calculateZdotWAndModifyW(reducedObservationMatrix, p, W, tauPrime, 1);
V += 2. * ZdotW + kernelApproximationDiagonal[tauPrime];
mbabadi (Contributor):

2. -> 2

samuelklee (Contributor Author):

I like making my doubles explicit...just a (bad?) habit!

changepointCandidates.addAll(windowCostLocalMinima.subList(0, Math.min(maxNumChangepoints, windowCostLocalMinima.size())));
}

if (changepointCandidates.isEmpty()) {
mbabadi (Contributor):

Theoretically speaking, under what conditions would changepointCandidates end up being empty? Is it always user error (bad window size, rank of the reduced data matrix being too low, empty data matrix, etc.)? I don't understand the nuances of kernel segmentation, but I have a feeling that changepointCandidates should be non-empty under fairly lax conditions. Could you isolate the edge cases and throw more informative exceptions at early stages, so that you would never encounter changepointCandidates.isEmpty() at this stage?

samuelklee (Contributor Author):

Yeah, good point. I think if you have non-empty data, you are always guaranteed a local minimum from PersistenceOptimizer, so I guess this check isn't necessary. I'll add a test for empty data and decide on the appropriate behavior in that case.

As for the unlikely case of a low-rank reduced data matrix or a bad SVD result, I think I'll cross my fingers and hope that appropriate exceptions are thrown from the SVD code...

samuelklee (Contributor Author):

Decided to add a warning and return an empty list in the case of no data.

mbabadi (Contributor):

but we still have UserException.BadInput?
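The behavior the author settled on, sketched in isolation (the logger is replaced by System.err, and the method shape is illustrative, not the PR's actual signature):

```java
import java.util.Collections;
import java.util.List;

public class EmptyDataBehavior {
    // warn and return no changepoints for empty data, instead of throwing
    public static List<Integer> findChangepoints(final List<Double> data) {
        if (data.isEmpty()) {
            System.err.println("Warning: no data provided; returning an empty list of changepoints.");
            return Collections.emptyList();
        }
        // the full kernel-segmentation algorithm would run here
        throw new UnsupportedOperationException("full segmentation algorithm omitted in this sketch");
    }
}
```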

.collect(Collectors.toList());
final int numChangepointsOptimal = totalSegmentationCostsPlusPenalties.indexOf(Collections.min(totalSegmentationCostsPlusPenalties));

logger.info(String.format("Found %d changepoints after applying penalties.", numChangepointsOptimal));
mbabadi (Contributor):

"Found %d changepoints after applying the segmentation penalty."

return windowCosts;
}

private static double calculateZdotWAndModifyW(final RealMatrix reducedObservationMatrix,
mbabadi (Contributor):

Could you add a brief description of signMultiplier, and of what modifying W means?

mbabadi (Contributor):

In my opinion, extracting calculateZdotWAndModifyW has made the code less readable :)

samuelklee (Contributor Author):

Actually, I think I agree...I'll change it back if @davidbenjamin doesn't mind. (snip snap snip snap!)

* @author Samuel Lee &lt;slee@broadinstitute.org&gt;
*/
public final class KernelSegmenterUnitTest extends BaseTest {
private static final int RANDOM_SEED = 1; //reset seed before each simulated test case
mbabadi (Contributor):

//random seed is reset to this value before each simulated test case

samuelklee (Contributor Author):

This was meant to be an imperative to future developers...I'll make it more explicit (//make sure to reset random seed to this value before each simulated test case)

public final class KernelSegmenterUnitTest extends BaseTest {
private static final int RANDOM_SEED = 1; //reset seed before each simulated test case

@DataProvider(name = "dataKernelSegmenter")
mbabadi (Contributor):

Could you briefly describe the two simulated test datasets in the unit-test class javadoc (e.g., the first case is Gaussian data with discrete jumps in the mean every 100 points; the second case simulates a multimodal series in which the mean alternates in sign successively while the absolute value of the mean jumps every 100 points)?

I would also include a few corner cases (empty data, empty window sizes, single-segment simulated data), and perhaps an example with a small but obvious changepoint on top of a long-wavelength sinusoidal background, showing how including small window sizes helps detect it.

samuelklee (Contributor Author):

Good points. Added everything except the sinusoidal background case. Although useful, I think it's a bit beyond the scope of the unit test.

mbabadi (Contributor):

*cough* lazy *cough* OK! ;-)
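The first simulated case discussed above (Gaussian noise with the mean jumping every 100 points) can be sketched as follows; segment length, noise level, and per-segment means are illustrative choices, not the test's actual parameters:

```java
import java.util.Random;

public class SimulatedDataExample {
    // Gaussian data whose mean jumps by one unit every segmentLength points;
    // the true changepoints sit at multiples of segmentLength
    public static double[] simulateGaussianJumps(final int numSegments,
                                                 final int segmentLength,
                                                 final double noiseStandardDeviation,
                                                 final long randomSeed) {
        final Random rng = new Random(randomSeed);  // reset the seed before each test case
        final double[] data = new double[numSegments * segmentLength];
        for (int segment = 0; segment < numSegments; segment++) {
            for (int i = 0; i < segmentLength; i++) {
                data[segment * segmentLength + i] = segment + noiseStandardDeviation * rng.nextGaussian();
            }
        }
        return data;
    }
}
```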

.collect(Collectors.toList());
final int numChangepointsOptimal = totalSegmentationCostsPlusPenalties.indexOf(Collections.min(totalSegmentationCostsPlusPenalties));

logger.info(String.format("Found %d changepoints after applying penalties.", numChangepointsOptimal));
samuelklee (Contributor Author):

Done (although I called it a changepoint penalty).


mbabadi commented Sep 26, 2017

I am satisfied :)

samuelklee (Contributor Author):

Thanks! Fixed up some inadvertent missed changes. Will squash and merge once tests pass, if @davidbenjamin gives the go ahead!

davidbenjamin (Contributor):

@samuelklee I sure do.
