Added KernelSegmenter. #3590
2ba7ab7 to 5737a46
5737a46 to 9de9ecc
@davidbenjamin Feel free to comment if you like, but @mbabadi will handle the main review. Thanks guys!
A few comments but I'm reading for my own benefit, with the review as a side effect.
final List<Integer> windowSizes,
final double numChangepointsPenaltyLinearFactor,
final double numChangepointsPenaltyLogLinearFactor,
final boolean sortByIndex) {

Consider implementing this as a nested enum:
public enum ChangepointSortOrder {
SORT_BY_INDEX, SORT_BY_BACKWARD_SELECTION
}
Done.
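For readers following along, here is a minimal sketch of how an enum like the one suggested could replace the `sortByIndex` boolean. The class name `ChangepointSortSketch` and the `order` helper are hypothetical, and the sketch assumes the incoming list is already in backward-selection order; it is not the code that was merged.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

final class ChangepointSortSketch {
    // Enum values copied from the review comment; the enclosing class is hypothetical.
    enum ChangepointSortOrder {
        SORT_BY_INDEX, SORT_BY_BACKWARD_SELECTION
    }

    // Assumption: the input is already in backward-selection order, so only
    // SORT_BY_INDEX needs to reorder anything.
    static List<Integer> order(final List<Integer> changepoints, final ChangepointSortOrder sortOrder) {
        final List<Integer> result = new ArrayList<>(changepoints);
        if (sortOrder == ChangepointSortOrder.SORT_BY_INDEX) {
            Collections.sort(result);
        }
        return result;
    }
}
```

A call site then reads `order(changepoints, ChangepointSortOrder.SORT_BY_INDEX)` rather than passing a bare `true`, which is the readability gain the comment is after.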
ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
Utils.validateArg(new HashSet<>(windowSizes).size() == windowSizes.size(), "Window sizes must all be unique.");

`windowSizes.stream().distinct().count() == windowSizes.size()` is slightly wordier but more direct.
Done.
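The two uniqueness checks discussed here are equivalent; a small self-contained sketch (the class and method names are hypothetical) makes that easy to confirm:

```java
import java.util.HashSet;
import java.util.List;

final class WindowSizeUniquenessSketch {
    // Original form: build a set and compare sizes.
    static boolean allUniqueViaSet(final List<Integer> windowSizes) {
        return new HashSet<>(windowSizes).size() == windowSizes.size();
    }

    // Suggested form: count distinct stream elements (count() returns a long).
    static boolean allUniqueViaStream(final List<Integer> windowSizes) {
        return windowSizes.stream().distinct().count() == windowSizes.size();
    }
}
```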
final double[] invSqrtSingularValues = Arrays.stream(svd.getSingularValues()).map(Math::sqrt).map(x -> 1. / (x + EPSILON)).toArray();
@Override
public double visit(int i, int j, double value) {
double sum = 0.;

return new IndexRange(0, numSubsample).sum(k -> kernel.apply(data.get(i), dataSubsample.get(k)) * svd.getU().getEntry(k, j) * invSqrtSingularValues[j])
Done.
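Reading the suggested one-liner together with the `invSqrtSingularValues` line, the matrix entry being computed appears to be the following. The symbols are shorthand introduced here, not names from the PR: x_i for `data.get(i)`, s_k for `dataSubsample.get(k)`, K for `kernel`, sigma_j for the j-th singular value, n for `numSubsample`, and epsilon for `EPSILON`.

```latex
Z_{ij} = \sum_{k=0}^{n-1} \frac{K(x_i, s_k)\, U_{kj}}{\sqrt{\sigma_j} + \epsilon}
```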
//for N x p matrix Z_ij, returns the N-dimensional vector sum(Z_ij * Z_ij, j = 0,..., p - 1),
//which are the diagonal elements K_ii of the approximate kernel matrix
private static double[] calculateKernelApproximationDiagonal(final RealMatrix reducedObservationMatrix) {
return IntStream.range(0, reducedObservationMatrix.getRowDimension()).boxed()

It's lighter and shorter to use `IndexRange`:
return new IndexRange(0, reducedObservationMatrix.getRowDimension())
    .mapToDouble(i -> Arrays.stream(reducedObservationMatrix.getRow(i)).map(z -> z * z).sum());
You could go a step further with
return new IndexRange(0, reducedObservationMatrix.getRowDimension())
    .mapToDouble(i -> MathUtils.square(reducedObservationMatrix.getRowVector(i).getNorm()));
Done.
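For anyone without the GATK and Commons Math classes at hand, the same row-wise sum of squares can be sketched on a plain `double[][]`. The class `KernelDiagonalSketch` and both method names are hypothetical stand-ins, not the merged implementation:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

final class KernelDiagonalSketch {
    // Sum of squares of a vector.
    static double sumOfSquares(final double[] values) {
        return Arrays.stream(values).map(v -> v * v).sum();
    }

    // For an N x p matrix z, returns the N diagonal elements K_ii = sum_j z[i][j]^2.
    static double[] diagonal(final double[][] z) {
        return IntStream.range(0, z.length).mapToDouble(i -> sumOfSquares(z[i])).toArray();
    }
}
```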
//calculate penalties as a function of the number of changepoints
final int numData = reducedObservationMatrix.getRowDimension();
final List<Double> changepointPenalties = IntStream.range(0, maxNumChangepoints + 1).boxed()
.map(numChangepoints -> numChangepointsPenaltyLinearFactor * numChangepoints

It's a one-liner, but I nonetheless think it's worth extracting a method for the changepoint penalty.
Agreed, done.
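A sketch of what such an extracted method might look like, based on the penalty A * C + B * C * log(N / C) described in the class javadoc. The method name and the handling of C = 0 (taken as the limiting value, zero) are assumptions made here; the merged code may treat that case differently.

```java
final class ChangepointPenaltySketch {
    // Penalty A * C + B * C * log(N / C) for C changepoints in N data points.
    // Assumption: for C = 0 the limit C * log(N / C) -> 0 is used, so the penalty is zero.
    static double penalty(final int numChangepoints,
                          final int numData,
                          final double linearFactor,
                          final double logLinearFactor) {
        if (numChangepoints == 0) {
            return 0.;
        }
        return linearFactor * numChangepoints
                + logLinearFactor * numChangepoints * Math.log((double) numData / numChangepoints);
    }
}
```

The cast to `double` avoids integer division inside the logarithm.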
.collect(Collectors.toList());

//construct initial list of all segments and initialize costs
final List<Integer> candidateStarts = changepointCandidates.stream().sorted().distinct()

How about combining starts and ends into a `List<IndexRange>` of candidate segments, which could be created from a single stream?
Segments as I've defined them are end inclusive, but `IndexRange` is supposed to represent an end exclusive interval, so I'll leave things as they are.
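To make the end-inclusive convention concrete, here is a rough sketch of building inclusive `[start, end]` pairs from changepoint candidates. It assumes that a changepoint is the last index of the segment to its left and that every candidate lies in `[0, numData - 2]`; both assumptions and the class name are introduced here for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

final class InclusiveSegmentsSketch {
    // Returns end-inclusive {start, end} pairs covering 0..numData-1.
    // Assumption: each changepoint is the last index of its left-hand segment
    // and lies in [0, numData - 2].
    static List<int[]> segments(final List<Integer> changepointCandidates, final int numData) {
        final List<Integer> sorted = changepointCandidates.stream().sorted().distinct()
                .collect(Collectors.toList());
        final List<int[]> result = new ArrayList<>();
        int start = 0;
        for (final int changepoint : sorted) {
            result.add(new int[]{start, changepoint});
            start = changepoint + 1;
        }
        result.add(new int[]{start, numData - 1});
        return result;
    }
}
```

With an end-exclusive type, each pair would instead be `(start, changepoint + 1)`, which is the mismatch the reply refers to.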
final List<Double> costsForMergedSegmentPairs = IntStream.range(0, numSegments - 1).boxed()
.map(i -> new Segment(candidateStarts.get(i), candidateEnds.get(i + 1), reducedObservationMatrix, kernelApproximationDiagonal).cost)
.collect(Collectors.toList()); //cost of each adjacent pair when considered as a single segment
final List<Double> costsForMergingSegmentPairs = IntStream.range(0, numSegments - 1).boxed()

If you use `double[]` above this reduces to `MathArrays.ebeSubtract`.
Good point. However, I'm removing elements from the lists, so I'd rather use an `ArrayList` (since I like to use arrays to represent fixed-length quantities).
//initialize quantities for recurrence
double D = kernelApproximationDiagonal[start];
final double[] W = Arrays.copyOf(reducedObservationMatrix.getRow(start), p);
double V = Arrays.stream(W).map(w -> w * w).sum();

Might be worth extracting a `MathUtils` method for the sum of squares of a `double[]`.
Where are the boy scouts of yesteryear...
//update quantities in left segment
leftD -= kernelApproximationDiagonal[start];
ZdotW = 0.;
for (int j = 0; j < p; j++) {

Can you extract a method for this four-line motif?
I thought about it, but decided against it because such a method needs to either return `ZdotW` and modify the `W` costs as a side effect or return a pair of quantities. There are also enough variations on the motif to require parameterization of the method. But it does shorten up the code a bit, so done!
Whoops, just realized I broke the tests with some changes in a previous PR. Will fix and address @davidbenjamin's comments, thanks!
Codecov Report
@@              Coverage Diff               @@
##              master     #3590     +/-    ##
===============================================
+ Coverage    79.736%   80.094%   +0.358%
- Complexity    18148     18799      +651
===============================================
  Files          1217      1226        +9
  Lines         66602     69015     +2413
  Branches      10429     11073      +644
===============================================
+ Hits          53106     55277     +2171
- Misses         9289      9415      +126
- Partials       4207      4323      +116
I'm satisfied.
@mbabadi Can you review? I need to carry changes from this PR forward in the rest of the branch. Thanks!
I'm halfway through, will finish today :)
On Sep 26, 2017 9:14 AM, "samuelklee" wrote:
> @mbabadi Can you review? I need to carry changes from this PR forward in the rest of the branch. Thanks!
Great PR! The code is well-organized and well-written :)
* This gives the global cost as a function of the number of changepoints <i>C</i>.
* </ol>
* <ol>
* 6) Add a penalty <i>A * C + B * C * log N / C</i> to the global cost and find the minimum to determine the

`log N / C` -> `log(N / C)` for clarity.
Also, perhaps add that `T` is the generic data type, and that this class works for segmenting generic sequential data so long as a kernel `(T, T) -> R` is provided.
Done.
I guess you forgot to change all instances of `T` to `DATA`.

private final List<T> data;

public KernelSegmenter(final List<T> data) {

What about making `KernelSegmenter` a fully static class? Use a private constructor to prevent instantiation; `data` is only used in `findChangepoints`. It can have a `public static` declaration and take `data` as an additional argument.
I think I was allowing for the possibility of passing the data once and finding multiple sets of changepoints. This could also be done with a static method, I guess...not sure if there are any advantages/disadvantages either way, so I'll leave it as is, if you don't mind!
The advantage of the static implementation would be a small saving in instantiation space/time cost, but that's negligible for our current use cases; so I'm satisfied.
final boolean sortByIndex) {
ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");

Also assert `!windowSizes.isEmpty()`.
Done.
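As an illustration of why the emptiness check matters: `allMatch` returns true on an empty stream, so an empty `windowSizes` list would otherwise pass the positivity check. The sketch below uses a hypothetical `require` helper and `IllegalArgumentException` in place of GATK's `Utils.validateArg`, and the message for the empty case is invented for the example.

```java
import java.util.List;

final class WindowSizeValidationSketch {
    static void validate(final List<Integer> windowSizes) {
        // allMatch on an empty stream is vacuously true, so check emptiness first.
        require(!windowSizes.isEmpty(), "At least one window size must be provided.");
        require(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
    }

    // Hypothetical stand-in for Utils.validateArg.
    private static void require(final boolean condition, final String message) {
        if (!condition) {
            throw new IllegalArgumentException(message);
        }
    }
}
```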

//perform SVD of kernel matrix of subsampled data
logger.info(String.format("Performing SVD of kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
final SingularValueDecomposition svd = new SingularValueDecomposition(subKernelMatrix);

I believe the SVD implementation in ojAlgo is substantially faster than Apache's. Also, you may want to consider using `sgesvd` from nd4j (http://nd4j.org/doc/org/nd4j/linalg/api/blas/Lapack.html) for more speed, if SVD is a computational bottleneck in your algorithm. The problem is that you are using Apache matrices, so this solution requires adaptors between `RealMatrix` and `INDArray`. By the way, the garbage generation issues of nd4j have been solved in nd4j 0.9.0 using workspaces (see https://deeplearning4j.org/workspaces).
Actually, I think the bottleneck turned out to be the original non-matrix-multiplication implementation of this method.
I'll keep using Apache's SVD for now, but might consider implementing randomized SVD and using that here in the future.
for (final int tauPrime : indices) {
D += kernelApproximationDiagonal[tauPrime];
final double ZdotW = calculateZdotWAndModifyW(reducedObservationMatrix, p, W, tauPrime, 1);
V += 2. * ZdotW + kernelApproximationDiagonal[tauPrime];

`2.` -> `2`
I like making my doubles explicit...just a (bad?) habit!
changepointCandidates.addAll(windowCostLocalMinima.subList(0, Math.min(maxNumChangepoints, windowCostLocalMinima.size())));
}

if (changepointCandidates.isEmpty()) {

Theoretically speaking, under what conditions would `changepointCandidates` end up being empty? Is it always user error (bad window size, rank of the reduced data matrix being too low, empty data matrix, etc.)? I don't understand the nuances of kernel segmentation, but I have a feeling that `changepointCandidates` should be non-empty under fairly lax conditions. Could you isolate the edge cases and throw more informative exceptions at early stages such that you would never encounter `changepointCandidates.isEmpty()` at this stage?
Yeah, good point. I think if you have non-empty data, you are always guaranteed a local minimum from PersistenceOptimizer, so I guess this check isn't necessary. I'll add a test for empty data and decide on the appropriate behavior in that case.
As for the unlikely case of a low-rank reduced data matrix or a bad SVD result, I think I'll cross my fingers and hope that appropriate exceptions are thrown from the SVD code...
Decided to add a warning and return an empty list in the case of no data.
but we still have `UserException.BadInput`?
.collect(Collectors.toList());
final int numChangepointsOptimal = totalSegmentationCostsPlusPenalties.indexOf(Collections.min(totalSegmentationCostsPlusPenalties));

logger.info(String.format("Found %d changepoints after applying penalties.", numChangepointsOptimal));

"Found %d changepoints after applying the segmentation penalty."
return windowCosts;
}

private static double calculateZdotWAndModifyW(final RealMatrix reducedObservationMatrix,

Could you add a brief description of the `signMultiplier`, and of what modifying `W` means?
In my opinion, extracting `calculateZdotWAndModifyW` has made the code less readable :)
Actually, I think I agree...I'll change it back if @davidbenjamin doesn't mind. (snip snap snip snap!)
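For context on what the extracted helper seems to be doing, here is a guess reconstructed from the call site `V += 2. * ZdotW + kernelApproximationDiagonal[tauPrime]`: if V is the squared norm of the running sum W, then adding a row z gives |W + z|^2 = V + 2 z.W + z.z, with z.W evaluated before W is updated. The signature below works on plain arrays and is an assumption, not the method from the PR.

```java
final class ZdotWSketch {
    // Returns z . w using the incoming w, then updates w in place to w + signMultiplier * z.
    // signMultiplier = +1 adds the row to the running sum; -1 removes it.
    static double zDotWAndModifyW(final double[] z, final double[] w, final int signMultiplier) {
        double zDotW = 0.;
        for (int j = 0; j < w.length; j++) {
            zDotW += z[j] * w[j];
            w[j] += signMultiplier * z[j];
        }
        return zDotW;
    }
}
```

The side effect on `w` combined with a return value is exactly the awkwardness both reviewers point at, which may be why inlining the loop reads better.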
* @author Samuel Lee <slee@broadinstitute.org>
*/
public final class KernelSegmenterUnitTest extends BaseTest {
private static final int RANDOM_SEED = 1; //reset seed before each simulated test case

//random seed is reset to this value before each simulated test case
This was meant to be an imperative to future developers...I'll make it more explicit (`//make sure to reset random seed to this value before each simulated test case`)
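A small sketch of what resetting the seed buys: if each simulated case builds its own generator from `RANDOM_SEED`, its data does not depend on how many draws earlier cases consumed. The class and method names are hypothetical, and `java.util.Random` stands in for whatever generator the test actually uses.

```java
import java.util.Random;

final class SeedResetSketch {
    static final int RANDOM_SEED = 1;

    // Builds a fresh generator from RANDOM_SEED on every call, so a test case
    // sees the same draws regardless of which cases ran before it.
    static double[] simulateCase(final int numPoints) {
        final Random rng = new Random(RANDOM_SEED);
        final double[] data = new double[numPoints];
        for (int i = 0; i < numPoints; i++) {
            data[i] = rng.nextGaussian();
        }
        return data;
    }
}
```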
public final class KernelSegmenterUnitTest extends BaseTest {
private static final int RANDOM_SEED = 1; //reset seed before each simulated test case

@DataProvider(name = "dataKernelSegmenter")

Could you briefly describe the two simulated test data sets in the unit test class javadoc (e.g. the first case is Gaussian data with discrete jumps in the mean every 100 points; the second case simulates a multimodal series such that the mean alternates in sign successively while the absolute value of the mean jumps every 100 points)?
I would also include a few corner cases (empty data, empty window size, single-segment simulated data), and perhaps an example of a small but obvious changepoint on a long-wavelength sinusoidal background, showing how including small window sizes helps detect it.
Good points. Added everything except the sinusoidal background case. Although useful, I think it's a bit beyond the scope of the unit test.
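To picture the first simulated case described in the review (Gaussian data whose mean jumps every 100 points), here is an illustrative generator. It is not the test's actual data provider; the class name, parameters, and the convention that a changepoint is the last index before a jump are all assumptions.

```java
import java.util.Random;

final class SimulatedSegmentsSketch {
    // Gaussian noise around a mean that switches every segmentLength points,
    // cycling through segmentMeans.
    static double[] gaussianWithMeanJumps(final int numPoints,
                                          final int segmentLength,
                                          final double[] segmentMeans,
                                          final double sigma,
                                          final Random rng) {
        final double[] data = new double[numPoints];
        for (int i = 0; i < numPoints; i++) {
            final double mean = segmentMeans[(i / segmentLength) % segmentMeans.length];
            data[i] = mean + sigma * rng.nextGaussian();
        }
        return data;
    }

    // Assumption: a changepoint is the last index before each jump in the mean.
    static int[] expectedChangepoints(final int numPoints, final int segmentLength) {
        final int count = (numPoints - 1) / segmentLength;
        final int[] changepoints = new int[count];
        for (int k = 1; k <= count; k++) {
            changepoints[k - 1] = k * segmentLength - 1;
        }
        return changepoints;
    }
}
```

The single-segment corner case mentioned above corresponds to `segmentLength >= numPoints`, for which `expectedChangepoints` returns an empty array.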
coughlazycough OK! ;-)
.collect(Collectors.toList());
final int numChangepointsOptimal = totalSegmentationCostsPlusPenalties.indexOf(Collections.min(totalSegmentationCostsPlusPenalties));

logger.info(String.format("Found %d changepoints after applying penalties.", numChangepointsOptimal));

Done (although I called it a `changepoint penalty`).
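The selection step behind this log message can be sketched in isolation: add the penalty for each candidate number of changepoints to the corresponding segmentation cost and take the index of the minimum. The class and method names below are hypothetical; only the `indexOf(Collections.min(...))` idiom comes from the diff.

```java
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

final class OptimalChangepointCountSketch {
    // costs.get(c) and penalties.get(c) are the cost and penalty for c changepoints;
    // returns the c that minimizes their sum (the first such c on ties).
    static int optimalNumChangepoints(final List<Double> costs, final List<Double> penalties) {
        final List<Double> totals = IntStream.range(0, costs.size())
                .mapToObj(c -> costs.get(c) + penalties.get(c))
                .collect(Collectors.toList());
        return totals.indexOf(Collections.min(totals));
    }
}
```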
I am satisfied :)
Thanks! Fixed up some inadvertent missed changes. Will squash and merge once tests pass, if @davidbenjamin gives the go ahead!
@samuelklee I sure do.
Some of the posts in #2858 might be helpful in reviewing this PR.