
Convert preprocessing functions to multithreaded with GIL-released #540

Merged: calum-chamberlain merged 36 commits into develop from preprocess-accelerate on Mar 29, 2023

Conversation

@calum-chamberlain (Member) commented Mar 19, 2023

What does this PR do?

Pre-processing data is a major time sink in EQcorrscan. Previously we accelerated this with multiprocessing, which carries a significant memory overhead because data are copied between processes. The most time-consuming parts of the processing flow, however, are resampling and filtering. These components can be implemented so that they only run numpy and scipy functions that release the GIL. By re-writing them as GIL-releasing functions we can use lighter-weight multi-threading to parallelize across traces.

This PR attempts to do this by re-writing the pre-processing workflow into a more logical flow, with multithreading implemented for resampling, detrending and filtering. There are likely further efficiencies to be gained for gappy data.
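For illustration, a minimal sketch of the threaded, GIL-releasing pattern (this is not the PR's implementation; the helper names and the bandpass design here are hypothetical):

from concurrent.futures import ThreadPoolExecutor

from scipy.signal import butter, sosfiltfilt


def _filter_one(data, sos):
    # sosfiltfilt runs in compiled scipy code that releases the GIL,
    # so several of these calls can execute truly in parallel on threads.
    return sosfiltfilt(sos, data)


def threaded_bandpass(trace_data, lowcut, highcut, samp_rate, max_workers=4):
    # Hypothetical helper: filter a list of numpy arrays across threads.
    sos = butter(4, [lowcut, highcut], btype="bandpass", fs=samp_rate,
                 output="sos")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(lambda d: _filter_one(d, sos), trace_data))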

Why was it initiated? Any relevant Issues?

EQcorrscan is a bit slow in places!

PR Checklist

  • develop base branch selected?
  • This PR is not directly related to an existing issue (which has no PR yet).
  • All tests still pass.
  • Any new features or fixed regressions are covered by new tests.
  • Any new or changed features are fully documented.
  • Significant changes have been added to CHANGES.md.
  • First-time contributors have added their name to CONTRIBUTORS.md.

TODO:

  1. Confirm that the new preprocessing exactly matches the previous preprocessing
  2. Provide benchmarks for times and memory use between new and old functions
  3. Replace calls to shortproc, dayproc and process with calls to multi_process
  4. Ensure all tests pass after this (the preprocessing tests pass as of making this PR, and new tests have been added to compare the resampling, filtering and detrending methods to obspy native methods)
  5. Docstrings, and make sure new funcs get into online docs

@calum-chamberlain (Member Author)

This PR also removes the pyfftw dependency, which has been causing headaches with conda packaging; the numpy FFTs are usually as fast as, or faster than, the pyfftw FFTs.
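To illustrate the kind of change involved, a minimal frequency-domain resampling sketch built purely on numpy's FFTs (a hypothetical stand-in, not the function in this PR; the real code also handles windowing and anti-aliasing details):

import numpy as np


def fft_resample(data, old_rate, new_rate):
    # Hypothetical sketch: resample by truncating (or zero-padding) the
    # rfft spectrum and inverse-transforming at the new length, using
    # numpy's FFTs instead of pyfftw.
    npts = len(data)
    new_npts = int(round(npts * new_rate / old_rate))
    spectrum = np.fft.rfft(data)
    new_spectrum = np.zeros(new_npts // 2 + 1, dtype=complex)
    n = min(len(spectrum), len(new_spectrum))
    new_spectrum[:n] = spectrum[:n]
    # Rescale to preserve amplitudes after changing the transform length.
    return np.fft.irfft(new_spectrum, n=new_npts) * (new_npts / npts)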

@calum-chamberlain (Member Author) commented Mar 20, 2023

I still have some tweaking to do to get the tests passing - I think this is a matter of ensuring that the correct arguments are passed to the new function, which does the job of two previous functions. I ran some quick benchmarks this evening, but the memory figures looked spurious (when including children, the old dayproc reported 75 GB peak memory use on my 48 GB RAM machine...), so I wanted to work out what was going on before posting them. Edited to include memory from mprof. Tests were run on my AMD Ryzen 5 6-core, 12-thread CPU.

Timings come from %timeit runs using random data, filtering 2-10 Hz and resampling to 25 Hz.
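A rough sketch of how a benchmark like this could be set up (hypothetical; the exact benchmark script is not included in the PR):

import numpy as np
from obspy import Stream, Trace

# Synthetic day of random data: 12 channels at 100 Hz.
st = Stream([
    Trace(data=np.random.randn(86400 * 100),
          header={"sampling_rate": 100.0, "station": f"ST{i:02d}",
                  "channel": "HHZ"})
    for i in range(12)])

# In IPython, time the old and new processing paths, for example:
# %timeit multi_process(st.copy(), lowcut=2.0, highcut=10.0, filt_order=4,
#                       samp_rate=25.0, parallel=True)
# Memory figures come from the memory_profiler package's mprof tool,
# e.g. `mprof run benchmark_script.py` followed by `mprof plot`.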

12 channels x 1 day @ 100 Hz

Code          Time (s)  Memory (MiB)
Old parallel  5.18      15,203
Old serial    8.79      1,296
New parallel  2.93      3,389
New serial    5.45      1,449

24 channels x 1 day @ 100 Hz

Code          Time (s)  Memory (MiB)
Old parallel  8.51      26,460
Old serial    18.6      2,134
New parallel  6.07      4,954
New serial    10.9      2,438

So in these simple cases the new processing is faster, and significantly more memory efficient for parallel processing (but slightly worse for serial - possibly because more dictionaries are retained across all traces through the processing?).

This is heartening to see - @flixha do you think this is worth me pursuing further? I don't think it will take me too much work to iron out the discrepancies between old and new. If you are interested as well: because the internal functions are numpy and scipy, they could likely be swapped for cupy or similar GPU libraries to allow GPU acceleration of the pre-processing. Using threading here should also allow multiprocessing to be used for pre-computing successive steps in the main workflow, without running into issues of spawning children from forked processes.
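As a hedged sketch of what such a swap might look like (hypothetical; assumes cupy is installed and ignores the rest of the workflow):

import numpy as np

try:
    import cupy as cp  # optional GPU backend, assumed available
    xp = cp
except ImportError:
    xp = np


def simple_linear_detrend(data):
    # Hypothetical sketch: the same array code runs on numpy (CPU) or
    # cupy (GPU), because only generic array operations are used.
    data = xp.asarray(data)
    start, stop = float(data[0]), float(data[-1])
    line = xp.linspace(start, stop, data.shape[0])
    detrended = data - line
    # Move the result back to host memory if it was computed on the GPU.
    return cp.asnumpy(detrended) if xp is not np else detrended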

@calum-chamberlain (Member Author)

At the moment I am getting failures in the lag_calc tests, but only when they are run alongside other tests, which suggests that state is being retained between tests somewhere. A minimal test-set reproduction is:

py.test eqcorrscan/tests/find_peaks_test.py::TestEdgeCases eqcorrscan/tests/find_peaks_test.py::TestPeakFindSpeeds::test_multi_find_peaks eqcorrscan/tests/lag_calc_test.py::SyntheticTests::test_family_picking -v -s

which fails with:

AssertionError: 0.8974865078926086 != 1.0 within 1 places (0.10251349210739136 difference)

When run alone, this test passes. Something about the combination of the find-peaks tests (both the edge-case tests and the speed tests) results in differences in the correlations computed by a different function. This may be due to differences in the dataset going into lag-calc; I don't see failures with real data.

@calum-chamberlain (Member Author)

It looks like the issue is fixed by setting the random_seed in the class setup, rather than in the main block of the test file.
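In unittest terms the fix amounts to something like this (a schematic sketch; the seed value and test body are placeholders):

import unittest

import numpy as np


class SyntheticTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Seed inside the class setup so the synthetic data are identical
        # regardless of which other tests consumed random numbers earlier
        # in the session.
        np.random.seed(42)
        cls.data = np.random.randn(1000)

    def test_family_picking(self):
        # Sees the same data whether run alone or in a full test session.
        self.assertEqual(len(self.data), 1000)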

@calum-chamberlain (Member Author)

I think this is ready for review now.

@calum-chamberlain calum-chamberlain changed the title [WIP] Convert preprocessing functions to multithreaded with GIL-released Convert preprocessing functions to multithreaded with GIL-released Mar 22, 2023
@calum-chamberlain calum-chamberlain merged commit 6182e38 into develop Mar 29, 2023
@calum-chamberlain calum-chamberlain deleted the preprocess-accelerate branch March 29, 2023 03:17
@flixha (Collaborator) commented Apr 5, 2023

Hi @calum-chamberlain,
sorry for not being able to comment on this earlier. This change looks very good indeed! Definitely looking forward to trying out the GPU version as well, and exploring how far cupy works with ROCm.

I had a chance to check this out a bit now; one thing I noticed is that previously we more or less allowed traces with the same ID into the preprocessing functions (e.g., for templates with 2 channels on Z), but from what I saw, for example with the trace-id-keyed dicts here, this would no longer work fully.

One way to work around that could be to use the trace ID plus the trace starttime as the key, although the keys may then need updating when the starttime changes with padding. For example like this:

def _simple_qc(st, max_workers=None, chunksize=1):
    """
    Multithreaded simple QC of data.

    :param st: Stream of data to check
    :type st: obspy.core.Stream
    :param max_workers: Maximum number of threads to use
    :type max_workers: int
    :param chunksize: Number of traces to process per thread
    :type chunksize: int

    :return: dict of {tr.id: quality} where quality is bool
    """
    qual = dict()
    with ThreadPoolExecutor(max_workers) as executor:
        for tr, _qual in zip(st, executor.map(
                _check_daylong, (tr.data for tr in st), chunksize=chunksize)):
            key = tr.id + '_' + str(tr.stats.starttime)
            qual[key] = _qual
    return qual

def multi_process(st, lowcut, highcut, filt_order, samp_rate, parallel=False,
                  num_cores=False, starttime=None, endtime=None,
                  daylong=False, seisan_chan_names=False, fill_gaps=True,
                  ignore_length=False, ignore_bad_data=False):
    """
    Apply standardised processing workflow to data for matched-filtering

    Steps:

        #. Check length and continuity of data meets user-defined criteria
        #. Fill remaining gaps in data with zeros and record gap positions
        #. Detrend data (using a simple linear detrend to set start and
           end to 0)
        #. Pad data to length
        #. Resample in the frequency domain
        #. Detrend data (using a simple linear detrend to set start and
           end to 0)
        #. Zerophase Butterworth filter
        #. Re-check length
        #. Re-apply zero-padding to gap locations recorded in step 2 to remove
           filtering and resampling artefacts

    :param st: Stream to process
    :type st: obspy.core.Stream
    :param lowcut:
        Lowcut of butterworth filter in Hz. If set to None and highcut is
        given a highpass filter will be applied. If both lowcut and highcut
        are given, a bandpass filter will be applied. If lowcut and highcut
        are both None, no filtering will be applied.
    :type lowcut: float
    :param highcut:
        Highcut of butterworth filter in Hz. If set to None and lowcut is
        given a lowpass filter will be applied. If both lowcut and highcut
        are given, a bandpass filter will be applied. If lowcut and highcut
        are both None, no filtering will be applied.
    :type highcut: float
    :param filt_order: Filter order
    :type filt_order: int
    :param samp_rate: Desired sample rate of output data in Hz
    :type samp_rate: float
    :param parallel: Whether to process data in parallel (uses multi-threading)
    :type parallel: bool
    :param num_cores: Maximum number of cores to use for parallel processing
    :type num_cores: int
    :param starttime: Desired starttime of data
    :type starttime: obspy.core.UTCDateTime
    :param endtime: Desired endtime of data
    :type endtime: obspy.core.UTCDateTime
    :param daylong:
        Whether data should be considered to be one-day long. Setting this will
        assume that your data should start as close to the start of a day
        as possible given the sampling.
    :type daylong: bool
    :param seisan_chan_names:
        Whether to convert channel names to two-char seisan channel names
    :type seisan_chan_names: bool
    :param fill_gaps: Whether to fill-gaps in the data
    :type fill_gaps: bool
    :param ignore_length:
        Whether to ignore data that are not long enough.
    :type ignore_length: bool
    :param ignore_bad_data: Whether to ignore data that are excessively gappy
    :type ignore_bad_data: bool

    :return: Processed stream as obspy.core.Stream
    """
    outtic = default_timer()
    if isinstance(st, Trace):
        tracein = True
        st = Stream(st)
    else:
        tracein = False
    # Add sanity check for filter
    if highcut and highcut >= 0.5 * samp_rate:
        raise IOError('Highcut must be lower than the Nyquist')
    if highcut and lowcut:
        assert lowcut < highcut, f"Lowcut: {lowcut} above highcut: {highcut}"

    # Allow datetimes for starttime and endtime
    if starttime and not isinstance(starttime, UTCDateTime):
        starttime = UTCDateTime(starttime)
    if starttime is False:
        starttime = None
    if endtime and not isinstance(endtime, UTCDateTime):
        endtime = UTCDateTime(endtime)
    if endtime is False:
        endtime = None

    # Make sensible choices about workers and chunk sizes
    if parallel:
        if not num_cores:
            # We don't want to over-specify threads, we don't have IO
            # bound tasks
            max_workers = min(len(st), os.cpu_count())
        else:
            max_workers = min(len(st), num_cores)
    else:
        max_workers = 1
    chunksize = len(st) // max_workers

    st, length, clip, starttime = _sanitize_length(
        st=st, starttime=starttime, endtime=endtime, daylong=daylong)

    for tr in st:
        if len(tr.data) == 0:
            st.remove(tr)
            Logger.warning('No data for {0} after trim'.format(tr.id))

    # Do work
    # 1. Fill gaps and keep track of them
    gappy = {tr.id + '_' + str(tr.stats.starttime): False for tr in st}
    gaps = dict()
    for i, tr in enumerate(st):
        if isinstance(tr.data, np.ma.MaskedArray):
            key = tr.id + '_' + str(tr.stats.starttime)
            gappy[key] = True
            gaps[key], tr = _fill_gaps(tr)
            st[i] = tr

    # 2. Check for zeros and cope with bad data
    # ~ 4x speedup for 50 100 Hz daylong traces on 12 threads
    qual = _simple_qc(st, max_workers=max_workers, chunksize=chunksize)
    for key, _qual in qual.items():
        if not _qual:
            msg = ("Data have more zeros than actual data, please check the "
                   f"raw data set-up and manually sort it: {key}")
            if not ignore_bad_data:
                raise ValueError(msg)
            else:
                # Remove bad traces from the stream
                remove_traces = [
                    tr for tr in st 
                    if tr.id + '_' + str(tr.stats.starttime) == key]
                # Need to check whether trace is still in stream
                for tr in remove_traces:
                    if tr in st:
                        st.remove(tr)

    # 3. Detrend
    # ~ 2x speedup for 50 100 Hz daylong traces on 12 threads
    st = _multi_detrend(st, max_workers=max_workers, chunksize=chunksize)

    # 4. Check length and pad to length
    padded = {tr.id + '_' + str(tr.stats.starttime): (0., 0.) for tr in st}
    if clip:
        st.trim(starttime, starttime + length, nearest_sample=True)
        # Indexing because we are going to overwrite traces
        for i, _ in enumerate(st):
            if float(st[i].stats.npts / st[i].stats.sampling_rate) != length:
                key = st[i].id + '_' + str(st[i].stats.starttime)
                Logger.info(
                    'Data for {0} are not long-enough, will zero pad'.format(
                        key))
                st[i], padded[key] = _length_check(
                    st[i], starttime=starttime, length=length,
                    ignore_length=ignore_length,
                    ignore_bad_data=ignore_bad_data)
                # Update padded-dict with updated keys
                if st[i] is not None:
                    new_key = st[i].id + '_' + str(st[i].stats.starttime)
                    padded[new_key] = padded.pop(key)
                    gappy[new_key] = gappy.pop(key)
                    gaps[new_key] = gaps.pop(key)
        # Remove None traces that might be returned from length checking
        st.traces = [tr for tr in st if tr is not None]

    # Check that we actually still have some data
    if not _stream_has_data(st):
        if tracein:
            return st[0]
        return st

    # 5. Resample
    # ~ 3.25x speedup for 50 100 Hz daylong traces on 12 threads
    st = _multi_resample(
        st, sampling_rate=samp_rate, max_workers=max_workers,
        chunksize=chunksize)

    # Detrend again before filtering
    st = _multi_detrend(st, max_workers=max_workers, chunksize=chunksize)

    # 6. Filter
    # ~3.25x speedup for 50 100 Hz daylong traces on 12 threads
    st = _multi_filter(
        st, highcut=highcut, lowcut=lowcut, filt_order=filt_order,
        max_workers=max_workers, chunksize=chunksize)

    # 7. Reapply zeros after processing from 4
    for tr in st:
        # Pads default to (0., 0.), pads should only ever be positive.
        key = tr.id + '_' + str(tr.stats.starttime)
        if sum(padded[key]) == 0:
            continue
        Logger.debug("Reapplying zero pads post processing")
        Logger.debug(str(tr))
        pre_pad = np.zeros(int(padded[key][0] * tr.stats.sampling_rate))
        post_pad = np.zeros(int(padded[key][1] * tr.stats.sampling_rate))
        pre_pad_len = len(pre_pad)
        post_pad_len = len(post_pad)
        Logger.debug(
            f"Taking only valid data between {pre_pad_len} and "
            f"{tr.stats.npts - post_pad_len} samples")
        # Re-apply the pads, taking only the data section that was valid
        tr.data = np.concatenate(
            [pre_pad, tr.data[pre_pad_len: len(tr.data) - post_pad_len],
             post_pad])
        Logger.debug(str(tr))

    # 8. Recheck length
    for tr in st:
        if float(tr.stats.npts * tr.stats.delta) != length and clip:
            key = tr.id + '_' + str(tr.stats.starttime)
            Logger.info(f'Data for {key} are not of required length, will '
                        f'zero pad')
            # Use obspy's trim function with zero padding
            tr = tr.trim(starttime, starttime + length, pad=True, fill_value=0,
                         nearest_sample=True)
            # Update dicts in case key changed with starttime
            new_key = tr.id + '_' + str(tr.stats.starttime)
            padded[new_key] = padded.pop(key)
            gappy[new_key] = gappy.pop(key)
            gaps[new_key] = gaps.pop(key)
            # If there is one sample too many after this, drop the first
            # sample by convention
            if len(tr.data) == (length * tr.stats.sampling_rate) + 1:
                tr.data = tr.data[1:len(tr.data)]
            if abs((tr.stats.sampling_rate * length) -
                   tr.stats.npts) > tr.stats.delta:
                raise ValueError('Data are not required length for ' +
                                 tr.stats.station + '.' + tr.stats.channel)

    # 9. Re-insert gaps from 1
    for i, tr in enumerate(st):
        if gappy[tr.id + '_' + str(tr.stats.starttime)]:
            key = tr.id + '_' + str(tr.stats.starttime)
            st[i] = _zero_pad_gaps(tr, gaps[key], fill_gaps=fill_gaps)

    # 10. Clean up
    for tr in st:
        if len(tr.data) == 0:
            st.remove(tr)

    # 11. Account for seisan channel naming
    if seisan_chan_names:
        for tr in st:
            tr.stats.channel = tr.stats.channel[0] + tr.stats.channel[-1]

    if tracein:
        st.merge()
        return st[0]
    outtoc = default_timer()
    Logger.info('Pre-processing took: {0:.4f}s'.format(outtoc - outtic))
    return st
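For reference, a minimal usage sketch of the new entry point (assuming it is exposed from eqcorrscan.utils.pre_processing as in this PR; the file path and parameter values are placeholders):

from obspy import read

from eqcorrscan.utils.pre_processing import multi_process

# Merge first: pre-processing expects one trace per channel.
st = read("/path/to/day_of_data.mseed")
st.merge()

# Bandpass 2-10 Hz, resample to 25 Hz, multithreaded across traces.
st_processed = multi_process(
    st, lowcut=2.0, highcut=10.0, filt_order=4, samp_rate=25.0,
    parallel=True, num_cores=4)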

@calum-chamberlain (Member Author)

Good spot @flixha - although I never intended pre-processing to allow multiple traces for a single channel, and the docs have always said that data should be merged before being passed to the pre-processing functions (e.g. the note here). Handling multiple channels in a template is done during template cutting, after pre-processing.
