Convert preprocessing functions to multithreaded with GIL-released #540
Conversation
This PR also removes the pyfftw dependency, which has been causing headaches with conda packaging - the numpy FFTs are usually as fast as or faster than the pyfftw FFTs.
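For context, a rough way to sanity-check that claim on your own hardware is something like the following (illustrative only; the array size and the pyfftw interface used here are my choices, not from the PR):

import timeit

import numpy as np
import pyfftw.interfaces.numpy_fft as pyfftw_fft

# One day of 100 Hz data, as in the benchmarks below
data = np.random.randn(86400 * 100)

np_time = timeit.timeit(lambda: np.fft.rfft(data), number=10)
fftw_time = timeit.timeit(lambda: pyfftw_fft.rfft(data), number=10)
print(f"numpy rfft: {np_time:.3f}s, pyfftw rfft: {fftw_time:.3f}s")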
I still have some tweaking to do to get the tests passing - I think this is down to ensuring that the correct arguments are passed to the new function, which does the job of two previous functions. I ran some quick benchmarks this evening though; timings come from 12 channels x 1 day @ 100 Hz and 24 channels x 1 day @ 100 Hz (the timing tables themselves are not reproduced here).
So in these simple cases, the new processing is faster and significantly more memory efficient for parallel processing (but slightly worse for serial - possibly because more dictionaries are retained across all traces through the processing?). This is heartening to see - @flixha do you think this is worth me pursuing further? I don't think it will take too much work to iron out the discrepancies between old and new. If you are interested as well: because the internal functions are numpy and scipy, they could likely be replaced with cupy or similar GPU libraries to allow GPU acceleration of the pre-processing. Using threading here should also allow multiprocessing to be used for pre-computing successive steps in the main workflow without running into issues of spawning children from forked processes.
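As a rough illustration of the GPU idea (a sketch only - the function below is mine, not part of this PR; the amplitude scaling follows scipy.signal.resample's convention), a numpy-style frequency-domain resample translates to cupy almost mechanically:

import cupy as cp
import numpy as np

def gpu_resample(data, old_rate, new_rate):
    """Illustrative frequency-domain resample of a 1D trace on the GPU."""
    x = cp.asarray(data, dtype=cp.float64)
    npts_out = int(round(len(x) * new_rate / old_rate))
    spectrum = cp.fft.rfft(x)
    # irfft crops (downsample) or zero-pads (upsample) the spectrum to the
    # requested output length; rescale to preserve amplitudes
    out = cp.fft.irfft(spectrum, n=npts_out) * (npts_out / len(x))
    return cp.asnumpy(out)

# e.g. resample a day of 100 Hz noise to 50 Hz
resampled = gpu_resample(np.random.randn(86400 * 100), 100.0, 50.0)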
At the moment I am getting failures in lag_calc tests, but only when run alongside other tests, which suggests an issue with state being retained between tests somewhere. A minimal test-set reproduction is:

py.test eqcorrscan/tests/find_peaks_test.py::TestEdgeCases eqcorrscan/tests/find_peaks_test.py::TestPeakFindSpeeds::test_multi_find_peaks eqcorrscan/tests/lag_calc_test.py::SyntheticTests::test_family_picking -v -s

which fails with:

AssertionError: 0.8974865078926086 != 1.0 within 1 places (0.10251349210739136 difference)

When run alone, this test passes. Something about the combination of the find-peaks tests (both the edge-case tests and the speed tests) results in differences in the correlations from a different function - this may be due to differences in the dataset going into lag-calc; I don't see failures with real data.
It looks like the issue is fixed by putting the random_seed into the class setup, rather than in the main of the test file.
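For reference, the kind of change described looks something like this (a sketch, assuming unittest-style test classes; the class body here is illustrative, not the real test):

import unittest

import numpy as np

class SyntheticTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Seed here rather than at module level so that other test modules
        # run beforehand cannot advance the global random state first
        np.random.seed(42)
        cls.data = np.random.randn(1000)

    def test_family_picking(self):
        self.assertEqual(len(self.data), 1000)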
I think this is ready for review now.
Hi @calum-chamberlain, I had a chance to check this out a bit now. One thing I noticed is that previously we more or less allowed traces with the same ID into the preprocessing functions (e.g., for templates with 2 channels on Z), but from what I saw, for example here with the dicts keyed on trace IDs, this would not fully work any more. One way to work around that could be to use trace ID plus trace starttime as a key, although then the keys may need updating when the starttime changes with padding. For example like this:

def _simple_qc(st, max_workers=None, chunksize=1):
"""
Multithreaded simple QC of data.
:param st: Stream of data to check
:type st: obspy.core.Stream
:param max_workers: Maximum number of threads to use
:type max_workers: int
:param chunksize: Number of traces to process per thread
:type chunksize: int
:return: dict of {tr.id + '_' + str(tr.stats.starttime): quality} where quality is bool
"""
qual = dict()
with ThreadPoolExecutor(max_workers) as executor:
for tr, _qual in zip(st, executor.map(
_check_daylong, (tr.data for tr in st), chunksize=chunksize)):
key = tr.id + '_' + str(tr.stats.starttime)
qual[key] = _qual
return qual
def multi_process(st, lowcut, highcut, filt_order, samp_rate, parallel=False,
num_cores=False, starttime=None, endtime=None,
daylong=False, seisan_chan_names=False, fill_gaps=True,
ignore_length=False, ignore_bad_data=False):
"""
Apply standardised processing workflow to data for matched-filtering
Steps:
#. Check length and continuity of data meets user-defined criteria
#. Fill remaining gaps in data with zeros and record gap positions
#. Detrend data (using a simple linear detrend to set start and
end to 0)
#. Pad data to length
#. Resample in the frequency domain
#. Detrend data (using a simple linear detrend to set start and
end to 0)
#. Zerophase Butterworth filter
#. Re-check length
#. Re-apply zero-padding to gap locations recorded in step 2 to remove
filtering and resampling artefacts
:param st: Stream to process
:type st: obspy.core.Stream
:param lowcut:
Lowcut of butterworth filter in Hz. If set to None and highcut is
given a highpass filter will be applied. If both lowcut and highcut
are given, a bandpass filter will be applied. If lowcut and highcut
are both None, no filtering will be applied.
:type lowcut: float
:param highcut:
Highcut of butterworth filter in Hz. If set to None and lowcut is
given a lowpass filter will be applied. If both lowcut and highcut
are given, a bandpass filter will be applied. If lowcut and highcut
are both None, no filtering will be applied.
:type highcut: float
:param filt_order: Filter order
:type filt_order: int
:param samp_rate: Desired sample rate of output data in Hz
:type samp_rate: float
:param parallel: Whether to process data in parallel (uses multi-threading)
:type parallel: bool
:param num_cores: Maximum number of cores to use for parallel processing
:type num_cores: int
:param starttime: Desired starttime of data
:type starttime: obspy.core.UTCDateTime
:param endtime: Desired endtime of data
:type endtime: obspy.core.UTCDateTime
:param daylong:
Whether data should be considered to be one-day long. Setting this will
assume that your data should start as close to the start of a day
as possible given the sampling.
:type daylong: bool
:param seisan_chan_names:
Whether to convert channel names to two-char seisan channel names
:type seisan_chan_names: bool
:param fill_gaps: Whether to fill-gaps in the data
:type fill_gaps: bool
:param ignore_length:
Whether to ignore data that are not long enough.
:type ignore_length: bool
:param ignore_bad_data: Whether to ignore data that are excessively gappy
:type ignore_bad_data: bool
:return: Processed stream as obspy.core.Stream
"""
outtic = default_timer()
if isinstance(st, Trace):
tracein = True
st = Stream(st)
else:
tracein = False
# Add sanity check for filter
if highcut and highcut >= 0.5 * samp_rate:
raise IOError('Highcut must be lower than the Nyquist')
if highcut and lowcut:
assert lowcut < highcut, f"Lowcut: {lowcut} above highcut: {highcut}"
# Allow datetimes for starttime and endtime
if starttime and not isinstance(starttime, UTCDateTime):
starttime = UTCDateTime(starttime)
if starttime is False:
starttime = None
if endtime and not isinstance(endtime, UTCDateTime):
endtime = UTCDateTime(endtime)
if endtime is False:
endtime = None
# Make sensible choices about workers and chunk sizes
if parallel:
if not num_cores:
# We don't want to over-specify threads, we don't have IO
# bound tasks
max_workers = min(len(st), os.cpu_count())
else:
max_workers = min(len(st), num_cores)
else:
max_workers = 1
chunksize = len(st) // max_workers
st, length, clip, starttime = _sanitize_length(
st=st, starttime=starttime, endtime=endtime, daylong=daylong)
for tr in list(st):  # iterate over a copy so removal does not skip traces
if len(tr.data) == 0:
st.remove(tr)
Logger.warning('No data for {0} after trim'.format(tr.id))
# Do work
# 1. Fill gaps and keep track of them
gappy = {tr.id + '_' + str(tr.stats.starttime): False for tr in st}
gaps = dict()
for i, tr in enumerate(st):
if isinstance(tr.data, np.ma.MaskedArray):
key = tr.id + '_' + str(tr.stats.starttime)
gappy[key] = True
gaps[key], tr = _fill_gaps(tr)
st[i] = tr
# 2. Check for zeros and cope with bad data
# ~ 4x speedup for 50 100 Hz daylong traces on 12 threads
qual = _simple_qc(st, max_workers=max_workers, chunksize=chunksize)
for key, _qual in qual.items():
if not _qual:
msg = ("Data have more zeros than actual data, please check the "
f"raw data set-up and manually sort it: {key}")
if not ignore_bad_data:
raise ValueError(msg)
else:
# Remove bad traces from the stream
remove_traces = [
tr for tr in st
if tr.id + '_' + str(tr.stats.starttime) == key]
# Need to check whether trace is still in stream
for tr in remove_traces:
if tr in st:
st.remove(tr)
# 3. Detrend
# ~ 2x speedup for 50 100 Hz daylong traces on 12 threads
st = _multi_detrend(st, max_workers=max_workers, chunksize=chunksize)
# 4. Check length and pad to length
padded = {tr.id + '_' + str(tr.stats.starttime): (0., 0.) for tr in st}
if clip:
st.trim(starttime, starttime + length, nearest_sample=True)
# Indexing because we are going to overwrite traces
for i, _ in enumerate(st):
if float(st[i].stats.npts / st[i].stats.sampling_rate) != length:
key = st[i].id + '_' + str(st[i].stats.starttime)
Logger.info(
'Data for {0} are not long-enough, will zero pad'.format(
key))
st[i], padded[key] = _length_check(
st[i], starttime=starttime, length=length,
ignore_length=ignore_length,
ignore_bad_data=ignore_bad_data)
# Update padded-dict with updated keys
if st[i] is not None:
new_key = st[i].id + '_' + str(st[i].stats.starttime)
padded[new_key] = padded.pop(key)
gappy[new_key] = gappy.pop(key)
gaps[new_key] = gaps.pop(key)
# Remove None traces that might be returned from length checking
st.traces = [tr for tr in st if tr is not None]
# Check that we actually still have some data
if not _stream_has_data(st):
if tracein:
return st[0]
return st
# 5. Resample
# ~ 3.25x speedup for 50 100 Hz daylong traces on 12 threads
st = _multi_resample(
st, sampling_rate=samp_rate, max_workers=max_workers,
chunksize=chunksize)
# Detrend again before filtering
st = _multi_detrend(st, max_workers=max_workers, chunksize=chunksize)
# 6. Filter
# ~3.25x speedup for 50 100 Hz daylong traces on 12 threads
st = _multi_filter(
st, highcut=highcut, lowcut=lowcut, filt_order=filt_order,
max_workers=max_workers, chunksize=chunksize)
# 7. Reapply zeros after processing from 4
for tr in st:
# Pads default to (0., 0.), pads should only ever be positive.
key = tr.id + '_' + str(tr.stats.starttime)
if sum(padded[key]) == 0:
continue
Logger.debug("Reapplying zero pads post processing")
Logger.debug(str(tr))
pre_pad = np.zeros(int(padded[key][0] * tr.stats.sampling_rate))
post_pad = np.zeros(int(padded[key][1] * tr.stats.sampling_rate))
pre_pad_len = len(pre_pad)
post_pad_len = len(post_pad)
Logger.debug(
f"Taking only valid data between {pre_pad_len} and "
f"{tr.stats.npts - post_pad_len} samples")
# Re-apply the pads, taking only the data section that was valid
tr.data = np.concatenate(
[pre_pad, tr.data[pre_pad_len: len(tr.data) - post_pad_len],
post_pad])
Logger.debug(str(tr))
# 8. Recheck length
for tr in st:
if float(tr.stats.npts * tr.stats.delta) != length and clip:
key = tr.id + '_' + str(tr.stats.starttime)
Logger.info(f'Data for {key} are not of required length, will '
f'zero pad')
# Use obspy's trim function with zero padding
tr = tr.trim(starttime, starttime + length, pad=True, fill_value=0,
nearest_sample=True)
# Update dicts in case key changed with starttime
new_key = tr.id + '_' + str(tr.stats.starttime)
padded[new_key] = padded.pop(key)
gappy[new_key] = gappy.pop(key)
gaps[new_key] = gaps.pop(key)
# If there is one sample too many after this remove the first one
# by convention
if len(tr.data) == (length * tr.stats.sampling_rate) + 1:
tr.data = tr.data[1:len(tr.data)]
if abs((tr.stats.sampling_rate * length) -
tr.stats.npts) > tr.stats.delta:
raise ValueError('Data are not required length for ' +
tr.stats.station + '.' + tr.stats.channel)
# 9. Re-insert gaps from 1
for i, tr in enumerate(st):
if gappy[tr.id + '_' + str(tr.stats.starttime)]:
key = tr.id + '_' + str(tr.stats.starttime)
st[i] = _zero_pad_gaps(tr, gaps[key], fill_gaps=fill_gaps)
# 10. Clean up
for tr in list(st):  # iterate over a copy so removal does not skip traces
if len(tr.data) == 0:
st.remove(tr)
# 11. Account for seisan channel naming
if seisan_chan_names:
for tr in st:
tr.stats.channel = tr.stats.channel[0] + tr.stats.channel[-1]
if tracein:
st.merge()
return st[0]
outtoc = default_timer()
Logger.info('Pre-processing took: {0:.4f}s'.format(outtoc - outtic))
return st
Good spot @flixha - although I never intended pre-processing to allow multiple traces for a single channel, and the docs have always said that data should be merged before being passed to the pre-processing functions (e.g. the note here). Handling multiple channels in a template is done during template cutting, after pre-processing.
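A minimal sketch of that documented expectation (the file name and processing parameters are arbitrary, and the import path is assumed rather than taken from the PR):

from obspy import read

from eqcorrscan.utils.pre_processing import multi_process  # module path assumed

st = read("day_of_data.mseed")
# Merge so that each channel contributes a single trace before pre-processing
st.merge()
st = multi_process(
    st, lowcut=2.0, highcut=9.0, filt_order=4, samp_rate=20.0,
    parallel=True, num_cores=4)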
What does this PR do?
Pre-processing data is a major time sink in EQcorrscan. To accelerate it we previously used multiprocessing, which comes with a significant memory overhead because data are copied between processes. The parts of the processing flow that take the most time, however, are the resampling and filtering. These components can be implemented so that they only run numpy and scipy functions that release the GIL. By re-writing them as GIL-releasing functions we can use lighter multi-threading to parallelise across traces.
This PR attempts to do this by re-writing the pre-processing workflow into a more logical flow, with multithreading implemented for resampling, detrending and filtering. There are likely other efficiencies that could be made for gappy data.
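The general pattern is sketched below (illustrative only - the helper names and filter parameters are mine, not the PR's actual implementation): scipy's filtering routines release the GIL while they operate on numpy arrays, so mapping them over traces with a thread pool gives real parallelism without copying the data between processes.

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from scipy.signal import butter, sosfiltfilt

def _filter_trace(data, sos):
    # sosfiltfilt releases the GIL while filtering, so threads run in parallel
    return sosfiltfilt(sos, data)

def threaded_bandpass(traces, lowcut, highcut, samp_rate, max_workers=4):
    """Bandpass a list of numpy arrays using threads (illustrative helper)."""
    sos = butter(4, [lowcut, highcut], btype="bandpass", fs=samp_rate,
                 output="sos")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(lambda d: _filter_trace(d, sos), traces))

# e.g. 12 day-long 100 Hz traces
traces = [np.random.randn(86400 * 100) for _ in range(12)]
filtered = threaded_bandpass(traces, lowcut=2.0, highcut=9.0, samp_rate=100.0)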
Why was it initiated? Any relevant Issues?
EQcorrscan is a bit slow in places!
PR Checklist
- [ ] `develop` base branch selected?
- [ ] Significant changes have been added to `CHANGES.md`.
- [ ] First time contributors have added your name to `CONTRIBUTORS.md`.

TODO:
- [ ] Replace `shortproc`, `dayproc` and `process` with calls to `multi_process`
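One way that replacement could look (purely illustrative - the deprecation message and argument forwarding are my guesses, not code from this PR): the old entry points become thin wrappers that warn and forward to multi_process.

import warnings

def shortproc(st, lowcut, highcut, filt_order, samp_rate, **kwargs):
    """Deprecated: thin wrapper forwarding to multi_process (sketch only)."""
    warnings.warn(
        "shortproc is deprecated, use multi_process instead",
        DeprecationWarning)
    return multi_process(
        st, lowcut=lowcut, highcut=highcut, filt_order=filt_order,
        samp_rate=samp_rate, **kwargs)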