Merge pull request #5 from sa501428/optimize_speed
Optimize speed
sa501428 authored Apr 16, 2022
2 parents 752fe45 + d383726 commit 43fad5e
Showing 14 changed files with 263 additions and 240 deletions.
6 changes: 3 additions & 3 deletions src/javastraw/StrawGlobals.java
@@ -25,12 +25,12 @@
package javastraw;

public class StrawGlobals {
public static final String versionNum = "2.11.04";
// min hic file version supported
public static final String versionNum = "2.14.03";
public static final int minVersion = 6;
public static final int bufferSize = 2097152;

// implement Map scaling with this global variable
public static boolean allowDynamicBlockIndex = true;
public static int dynamicResolutionLimit = 50;

public static boolean printVerboseComments = false;
}
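
The two new globals control when the reader switches to on-demand block indexing; the constant simply names the threshold of 50 that was previously hard-coded inline. A minimal sketch of the gate, mirroring the condition that appears in the ReaderTools diff further down (the shouldUseDynamicBlockIndex helper and its binSize parameter are illustrative, not part of the commit):

// Illustrative helper only; the real check lives inline in
// ReaderTools.readMatrixZoomData (see that diff below). binSize is the
// bin size of the zoom level being parsed.
static boolean shouldUseDynamicBlockIndex(int binSize) {
    // Very fine resolutions have very large block indices, so their
    // entries are looked up lazily instead of being read up front.
    return StrawGlobals.allowDynamicBlockIndex
            && binSize < StrawGlobals.dynamicResolutionLimit;
}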
8 changes: 6 additions & 2 deletions src/javastraw/reader/Dataset.java
@@ -83,10 +83,14 @@ public void clearCache(boolean onlyClearInter) {
}
eigenvectorCache.clear();
normalizationVectorCache.clear();
normalizationTypes.clear();
matrices.clear();
}

public Matrix getMatrix(Chromosome chr1, Chromosome chr2) {
return getMatrix(chr1, chr2, -1);
}

public Matrix getMatrix(Chromosome chr1, Chromosome chr2, int specificResolution) {

// order is arbitrary, convention is lower # chr first
if (chr1 == null || chr2 == null) return null;
@@ -97,7 +101,7 @@ public Matrix getMatrix(Chromosome chr1, Chromosome chr2) {

if (m == null && reader != null) {
try {
m = reader.readMatrix(key);
m = reader.readMatrix(key, specificResolution);
matrices.put(key, m);
} catch (Exception e) {
System.err.println("Error fetching matrix for: " + chr1.getName() + "-" + chr2.getName());
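
The new three-argument overload lets callers ask for a single resolution, while the old two-argument form delegates with -1, which the reader treats as "no specific resolution". A hedged caller-side sketch (dataset, chr1, and chr2 are assumed to have been obtained through the usual javastraw entry points; 5000 is just an example bin size):

// Caller-side sketch; dataset, chr1 and chr2 come from elsewhere.
Matrix allZooms = dataset.getMatrix(chr1, chr2);        // old behavior, -1 passed internally
Matrix oneZoom = dataset.getMatrix(chr1, chr2, 5000);   // only the 5 kb zoom appears to get an eagerly built block index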
9 changes: 4 additions & 5 deletions src/javastraw/reader/DatasetReader.java
@@ -25,13 +25,13 @@
package javastraw.reader;

import javastraw.reader.block.Block;
import javastraw.reader.block.IndexEntry;
import javastraw.reader.datastructures.ListOfDoubleArrays;
import javastraw.reader.norm.NormalizationVector;
import javastraw.reader.type.HiCZoom;
import javastraw.reader.type.NormalizationType;

import java.io.IOException;
import java.util.List;

public interface DatasetReader {

@@ -43,12 +43,11 @@ public interface DatasetReader {

Dataset read() throws IOException;

Matrix readMatrix(String key) throws IOException;
Matrix readMatrix(String key, int resolution) throws IOException;

Block readNormalizedBlock(int blockNumber, String zdKey, NormalizationType no,
int chr1Index, int chr2Index, HiCZoom zoom) throws IOException;

List<Integer> getBlockNumbers(String zdKey);
int chr1Index, int chr2Index, HiCZoom zoom,
IndexEntry idx) throws IOException;

NormalizationVector readNormalizationVector(NormalizationType type, int chrIdx, HiCZoom.HiCUnit unit, int binSize) throws IOException;

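
Two signature changes matter to implementers: readMatrix now carries the requested resolution, and readNormalizedBlock receives the block's IndexEntry directly instead of having callers go through the removed getBlockNumbers accessor. A hedged sketch of how a zoom-data object might now drive the reader (blockIndices, blockNumber, zdKey, and the other locals are placeholders; only the reader methods come from the diff):

// Hypothetical call path after this commit.
IndexEntry idx = blockIndices.getBlock(blockNumber);    // index entry resolved by the caller...
if (idx != null) {
    Block block = reader.readNormalizedBlock(blockNumber, zdKey, norm,
            chr1Index, chr2Index, zoom, idx);            // ...and handed straight to the reader
}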
239 changes: 113 additions & 126 deletions src/javastraw/reader/DatasetReaderV2.java

Large diffs are not rendered by default.

88 changes: 52 additions & 36 deletions src/javastraw/reader/ReaderTools.java
@@ -32,14 +32,20 @@ public class ReaderTools {
private static final int maxLengthEntryName = 100;
private static final int MAX_BYTE_READ_SIZE = Integer.MAX_VALUE - 10;

static SeekableStream getValidStream(String path) throws IOException {
public static SeekableStream getValidStream(String path) throws IOException {
SeekableStream stream;
do {
stream = streamFactory.getStreamFor(path);
} while (stream == null);
return stream;
}

public static SeekableStream getValidStream(String path, long position) throws IOException {
SeekableStream stream = getValidStream(path);
stream.seek(position);
return stream;
}


static LittleEndianInputStream createStreamFromSeveralBuffers(LargeIndexEntry idx, String path) throws IOException {
List<byte[]> buffer = seekAndFullyReadLargeCompressedBytes(idx, path);
@@ -52,8 +58,7 @@ static LittleEndianInputStream createStreamFromSeveralBuffers(LargeIndexEntry id

static byte[] seekAndFullyReadCompressedBytes(IndexEntry idx, String path) throws IOException {
byte[] compressedBytes = new byte[idx.size];
SeekableStream stream = ReaderTools.getValidStream(path);
stream.seek(idx.position);
SeekableStream stream = ReaderTools.getValidStream(path, idx.position);
stream.readFully(compressedBytes);
stream.close();
return compressedBytes;
@@ -68,8 +73,7 @@ static List<byte[]> seekAndFullyReadLargeCompressedBytes(LargeIndexEntry idx, St
}
compressedBytes.add(new byte[(int) counter]);

SeekableStream stream = ReaderTools.getValidStream(path);
stream.seek(idx.position);
SeekableStream stream = ReaderTools.getValidStream(path, idx.position);
for (int i = 0; i < compressedBytes.size(); i++) {
stream.readFully(compressedBytes.get(i));
}
@@ -79,10 +83,8 @@ static List<byte[]> seekAndFullyReadLargeCompressedBytes(LargeIndexEntry idx, St

static Pair<MatrixZoomData, Long> readMatrixZoomData(Chromosome chr1, Chromosome chr2, int[] chr1Sites, int[] chr2Sites,
long filePointer, String path, boolean useCache,
Map<String, BlockIndex> blockIndexMap,
DatasetReader reader) throws IOException {
SeekableStream stream = ReaderTools.getValidStream(path);
stream.seek(filePointer);
DatasetReader reader, int specificResolution) throws IOException {
SeekableStream stream = ReaderTools.getValidStream(path, filePointer);
LittleEndianInputStream dis = new LittleEndianInputStream(new BufferedInputStream(stream, StrawGlobals.bufferSize));

String hicUnitStr = dis.readString();
@@ -104,29 +106,31 @@ static Pair<MatrixZoomData, Long> readMatrixZoomData(Chromosome chr1, Chromosome
int blockBinCount = dis.readInt();
int blockColumnCount = dis.readInt();

MatrixZoomData zd = new MatrixZoomData(chr1, chr2, zoom, blockBinCount, blockColumnCount, chr1Sites, chr2Sites,
reader);
zd.setUseCache(useCache);

int nBlocks = dis.readInt();

BlockIndices blockIndices;
long currentFilePointer = filePointer + (9 * 4) + hicUnitStr.getBytes().length + 1; // +1 byte for the null terminator of the unit string

if (binSize < 50 && StrawGlobals.allowDynamicBlockIndex) {
int maxPossibleBlockNumber = blockColumnCount * blockColumnCount - 1;
DynamicBlockIndex blockIndex = new DynamicBlockIndex(ReaderTools.getValidStream(path), nBlocks, maxPossibleBlockNumber, currentFilePointer);
blockIndexMap.put(zd.getKey(), blockIndex);
if (specificResolution > 0) {
if (binSize != specificResolution) {
int maxPossibleBlockNumber = blockColumnCount * blockColumnCount - 1;
blockIndices = new DynamicBlockIndices(ReaderTools.getValidStream(path), nBlocks, maxPossibleBlockNumber, currentFilePointer);
} else {
blockIndices = new BlockIndices(nBlocks);
blockIndices.populateBlocks(dis);
}
} else {
BlockIndex blockIndex = new BlockIndex(nBlocks);
blockIndex.populateBlocks(dis);
blockIndexMap.put(zd.getKey(), blockIndex);
if (binSize < StrawGlobals.dynamicResolutionLimit && StrawGlobals.allowDynamicBlockIndex) {
int maxPossibleBlockNumber = blockColumnCount * blockColumnCount - 1;
blockIndices = new DynamicBlockIndices(ReaderTools.getValidStream(path), nBlocks, maxPossibleBlockNumber, currentFilePointer);
} else {
blockIndices = new BlockIndices(nBlocks);
blockIndices.populateBlocks(dis);
}
}
currentFilePointer += (nBlocks * 16L);

long nBins1 = chr1.getLength() / binSize;
long nBins2 = chr2.getLength() / binSize;
double avgCount = (sumCounts / nBins1) / nBins2; // <= trying to avoid overflows
zd.setAverageCount(avgCount);
MatrixZoomData zd = new MatrixZoomData(chr1, chr2, zoom, blockBinCount, blockColumnCount, chr1Sites, chr2Sites,
reader, blockIndices, useCache, sumCounts);

stream.close();
return new Pair<>(zd, currentFilePointer);
@@ -135,9 +139,8 @@ static Pair<MatrixZoomData, Long> readMatrixZoomData(Chromosome chr1, Chromosome
static long readExpectedVectorInFooter(long currentPosition,
Map<String, ExpectedValueFunction> expectedValuesMap,
NormalizationType norm, int version, String path, DatasetReader reader) throws IOException {
SeekableStream stream = ReaderTools.getValidStream(path);
stream.seek(currentPosition);
LittleEndianInputStream dis = new LittleEndianInputStream(new BufferedInputStream(stream, StrawGlobals.bufferSize));
SeekableStream stream = ReaderTools.getValidStream(path, currentPosition);
LittleEndianInputStream dis = new LittleEndianInputStream(new BufferedInputStream(stream, 50));
String unitString = dis.readString();
currentPosition += (unitString.length() + 1);
HiCZoom.HiCUnit unit = HiCZoom.valueOfUnit(unitString);
@@ -146,14 +149,17 @@ static long readExpectedVectorInFooter(long currentPosition,

long[] nValues = new long[1];
currentPosition += readVectorLength(dis, nValues, version);

/* todo time
if (binSize >= 500) {
currentPosition = ReaderTools.readWholeNormalizationVector(currentPosition, dis, expectedValuesMap, unit, binSize,
nValues[0], norm, version);
} else {
currentPosition = ReaderTools.setUpPartialVectorStreaming(currentPosition, expectedValuesMap, unit, binSize,
nValues[0], norm, version, path, reader);
}
*/
currentPosition = ReaderTools.setUpPartialVectorStreaming(currentPosition, expectedValuesMap, unit, binSize,
nValues[0], norm, version, path, reader);
stream.close();
return currentPosition;
}
@@ -179,14 +185,11 @@ static long setUpPartialVectorStreaming(long currentPosition, Map<String, Expect
skipPosition += (nValues * 8);
}

SeekableStream stream = ReaderTools.getValidStream(path);
stream.seek(skipPosition);
LittleEndianInputStream dis = new LittleEndianInputStream(new BufferedInputStream(stream, StrawGlobals.bufferSize));
//long skipPosition = stream.position();
int nNormalizationFactors = dis.readInt();
SeekableStream stream = ReaderTools.getValidStream(path, skipPosition);
int nNormalizationFactors = ReaderTools.readIntFromBytes(stream);
currentPosition = skipPosition + 4;

NormFactorMapReader hmReader = new NormFactorMapReader(nNormalizationFactors, version, dis);
NormFactorMapReader hmReader = new NormFactorMapReader(nNormalizationFactors, version, currentPosition, path);
currentPosition += hmReader.getOffset();

ExpectedValueFunction df = new ExpectedValueFunctionImpl(norm, unit, binSize, nValues,
@@ -211,7 +214,7 @@ static long readWholeNormalizationVector(long currentPosition, LittleEndianInput
int nNormalizationFactors = dis.readInt();
currentPosition += 4;

NormFactorMapReader hmReader = new NormFactorMapReader(nNormalizationFactors, version, dis);
NormFactorMapReader hmReader = new NormFactorMapReader(nNormalizationFactors, version, currentPosition, null);
currentPosition += hmReader.getOffset();

String key = ExpectedValueFunction.getKey(unit, binSize, norm);
@@ -290,4 +293,17 @@ static int[] readSites(long position, int nSites, String path) throws IOExceptio
}
return sites;
}

public static int readIntFromBytes(SeekableStream stream) throws IOException {
byte[] buffer = new byte[4];
int actualBytes = stream.read(buffer);
if (actualBytes == 4) {
LittleEndianInputStream dis = new LittleEndianInputStream(new ByteArrayInputStream(buffer));
return dis.readInt();
} else {
System.err.println("Actually read " + actualBytes + " bytes instead of " + 4);
System.exit(110);
return 0;
}
}
}
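
Two small helpers do much of the tidying in this file: the getValidStream(path, position) overload folds the previous open-then-seek pair into one call, and readIntFromBytes reads a single little-endian int without wrapping the stream in the 2 MB buffered reader. A sketch of the two together (readFooterCount, hicPath, and footerOffset are illustrative names, not part of the commit):

// Illustrative use of the new helpers; only the ReaderTools methods are real.
static int readFooterCount(String hicPath, long footerOffset) throws IOException {
    SeekableStream stream = ReaderTools.getValidStream(hicPath, footerOffset); // open and seek in one call
    int count = ReaderTools.readIntFromBytes(stream);                          // one 4-byte little-endian int
    stream.close();
    return count;
}

This is the same pattern setUpPartialVectorStreaming now uses to read nNormalizationFactors.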
src/javastraw/reader/block/BlockIndices.java (renamed from BlockIndex.java)
@@ -32,29 +32,33 @@
import java.util.List;
import java.util.Map;

public class BlockIndex {
protected final Map<Integer, IndexEntry> blockIndex;
public class BlockIndices {
protected final Map<Integer, IndexEntry> blockIndices;
protected final int numBlocks;

public BlockIndex(int nBlocks) {
public BlockIndices(int nBlocks) {
numBlocks = nBlocks;
blockIndex = new HashMap<>(nBlocks);
blockIndices = new HashMap<>(nBlocks);
}

public void populateBlocks(LittleEndianInputStream dis) throws IOException {
for (int b = 0; b < numBlocks; b++) {
int blockNumber = dis.readInt();
long filePosition = dis.readLong();
int blockSizeInBytes = dis.readInt();
blockIndex.put(blockNumber, new IndexEntry(filePosition, blockSizeInBytes));
blockIndices.put(blockNumber, new IndexEntry(filePosition, blockSizeInBytes));
}
}

public List<Integer> getBlockNumbers() {
return new ArrayList<>(blockIndex.keySet());
return new ArrayList<>(blockIndices.keySet());
}

public IndexEntry getBlock(int blockNumber) {
return blockIndex.get(blockNumber);
return blockIndices.get(blockNumber);
}

public void clearCache() {
blockIndices.clear();
}
}
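
Each record in the on-disk block index is 16 bytes: a 4-byte block number, an 8-byte file position, and a 4-byte compressed size, which is why readMatrixZoomData advances its file pointer by nBlocks * 16. A sketch of the renamed class in use (dis and nBlocks are assumed to come from the reader; 42 is an arbitrary block number):

// Sketch: populate the index from a stream positioned at its first record,
// then look blocks up by number.
BlockIndices indices = new BlockIndices(nBlocks);
indices.populateBlocks(dis);                 // nBlocks records, 16 bytes each
IndexEntry entry = indices.getBlock(42);     // null if block 42 was never written
if (entry != null) {
    // entry.position and entry.size locate the compressed block in the file
}
indices.clearCache();                        // new in this commit: drops all cached entries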
src/javastraw/reader/block/DynamicBlockIndices.java (renamed from DynamicBlockIndex.java)
@@ -32,20 +32,20 @@
import java.io.IOException;
import java.util.List;

public class DynamicBlockIndex extends BlockIndex {
public class DynamicBlockIndices extends BlockIndices {

private final int maxBlocks;
private final long minPosition, maxPosition;
private Integer blockNumberRangeMin = null, blockNumberRangeMax = null;
private Long mapFileBoundsMin = null, mapFileBoundsMax = null;
private final SeekableStream stream;

public DynamicBlockIndex(SeekableStream stream, int numBlocks, int maxBlocks, long minPosition) {
super(numBlocks);
public DynamicBlockIndices(SeekableStream stream, int numBlocks, int maxBlocks, long minPosition) {
super(numBlocks / 2); // with dynamic indexing, not every block is expected to be loaded, so the backing map starts smaller
this.stream = stream;
this.maxBlocks = maxBlocks;
this.minPosition = minPosition;
maxPosition = minPosition + numBlocks * 16;
maxPosition = minPosition + numBlocks * 16L;
}


@@ -59,8 +59,8 @@ public List<Integer> getBlockNumbers() {
public IndexEntry getBlock(int blockNumber) {
if (blockNumber > maxBlocks) {
return null;
} else if (blockIndex.containsKey(blockNumber)) {
return blockIndex.get(blockNumber);
} else if (blockIndices.containsKey(blockNumber)) {
return blockIndices.get(blockNumber);
} else if (blockNumber == 0) {
try {
return searchForBlockIndexEntry(blockNumber, this.minPosition, this.minPosition + 16);
@@ -109,7 +109,7 @@ private IndexEntry searchForBlockIndexEntry(int blockNumber, long boundsMin, lon
int blockNumberFound = dis.readInt();
long filePosition = dis.readLong();
int blockSizeInBytes = dis.readInt();
blockIndex.put(blockNumberFound, new IndexEntry(filePosition, blockSizeInBytes));
blockIndices.put(blockNumberFound, new IndexEntry(filePosition, blockSizeInBytes));
if (firstBlockNumber == null) firstBlockNumber = blockNumberFound;
lastBlockNumber = blockNumberFound;
pointer += 16;
Expand All @@ -122,7 +122,7 @@ private IndexEntry searchForBlockIndexEntry(int blockNumber, long boundsMin, lon
blockNumberRangeMax = lastBlockNumber;
}

return blockIndex.get(blockNumber);
return blockIndices.get(blockNumber);
}
// Midpoint in units of 16 byte chunks
int nEntries = (int) ((boundsMax - boundsMin) / 16);
Expand All @@ -141,8 +141,8 @@ private IndexEntry searchForBlockIndexEntry(int blockNumber, long boundsMin, lon
blockSizeInBytes = dis.readInt();
}
if (blockNumberFound == blockNumber) {
blockIndex.put(blockNumberFound, new IndexEntry(filePosition, blockSizeInBytes));
return blockIndex.get(blockNumber);
blockIndices.put(blockNumberFound, new IndexEntry(filePosition, blockSizeInBytes));
return blockIndices.get(blockNumber);
} else if (blockNumber > blockNumberFound) {
return searchForBlockIndexEntry(blockNumber, positionToSeek + 16, boundsMax);
} else {
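
DynamicBlockIndices never reads the full index: getBlock first checks the in-memory map, then binary-searches the 16-byte records between minPosition and maxPosition directly on disk, caching whatever entries it passes over. A simplified sketch of that midpoint search (this is not the class's exact code; the cached-range bookkeeping and the linear scan near the target are omitted, and stream is assumed to be the class's SeekableStream field):

// Simplified sketch of searchForBlockIndexEntry's on-disk binary search.
// Each record is 16 bytes: int blockNumber, long filePosition, int size.
private IndexEntry search(int target, long boundsMin, long boundsMax) throws IOException {
    long nEntries = (boundsMax - boundsMin) / 16;
    if (nEntries < 1) return null;                       // exhausted the range
    long midPosition = boundsMin + (nEntries / 2) * 16;  // midpoint in whole records
    stream.seek(midPosition);
    LittleEndianInputStream dis = new LittleEndianInputStream(new BufferedInputStream(stream));
    int found = dis.readInt();
    long filePosition = dis.readLong();
    int sizeInBytes = dis.readInt();
    if (found == target) {
        blockIndices.put(found, new IndexEntry(filePosition, sizeInBytes));
        return blockIndices.get(found);
    } else if (target > found) {
        return search(target, midPosition + 16, boundsMax); // look in the upper half
    } else {
        return search(target, boundsMin, midPosition);       // look in the lower half
    }
}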
(The remaining changed files in this commit are not rendered in this view.)
