diff --git a/bin/minifollowups/pycbc_injection_minifollowup b/bin/minifollowups/pycbc_injection_minifollowup
index dab74e1d613..be849fc3bbf 100644
--- a/bin/minifollowups/pycbc_injection_minifollowup
+++ b/bin/minifollowups/pycbc_injection_minifollowup
@@ -182,12 +182,13 @@ trigger_times = {}
 for trig in single_triggers:
     ifo = trig.ifo
     with HFile(trig.lfn, 'r') as trig_f:
-        trigger_idx[ifo], trigger_times[ifo], trigger_snrs[ifo] = \
+        trigger_idx[ifo], data_tuple = \
             trig_f.select(
                 nearby_missedinj,
                 f'{ifo}/end_time',
                 f'{ifo}/snr',
-                return_indices=True)
+            )
+        trigger_times[ifo], trigger_snrs[ifo] = data_tuple
 
 if len(missed) < num_events:
     num_events = len(missed)
@@ -309,10 +310,10 @@ for num_event in range(num_events):
     # Finding loudest template in this detector near to the injection:
     # First, find triggers close to the missed injection
     single_fname = args.single_detector_triggers[curr_ifo]
-    idx = HFile(single_fname).select(
+    idx, _ = HFile(single_fname).select(
         lambda t: abs(t - inj_params['tc']) < args.inj_window,
         f'{curr_ifo}/end_time',
-        indices_only=True,
+        return_data=False,
     )
 
     if len(idx) == 0:
diff --git a/bin/minifollowups/pycbc_page_snglinfo b/bin/minifollowups/pycbc_page_snglinfo
index e41be46c5e1..5d741e66946 100644
--- a/bin/minifollowups/pycbc_page_snglinfo
+++ b/bin/minifollowups/pycbc_page_snglinfo
@@ -159,7 +159,8 @@ else:
     # Name would be too long - just call it ranking statistic
     stat_name = 'Ranking Statistic'
     stat_name_long = ' with '.join(
-        [args.ranking_statistic, args.sngl_ranking])
+        [args.ranking_statistic, args.sngl_ranking]
+    )
 
 headers.append(stat_name)
 
diff --git a/bin/minifollowups/pycbc_plot_trigger_timeseries b/bin/minifollowups/pycbc_plot_trigger_timeseries
index 6cc5c575375..22e60dc4786 100644
--- a/bin/minifollowups/pycbc_plot_trigger_timeseries
+++ b/bin/minifollowups/pycbc_plot_trigger_timeseries
@@ -61,11 +61,11 @@ for ifo in args.single_trigger_files.keys():
 
     # Identify trigger idxs within window of trigger time
     with HFile(args.single_trigger_files[ifo], 'r') as data:
-        idx = data.select(
+        idx, _ = data.select(
            lambda endtime: abs(endtime - t) < args.window,
            'end_time',
            group=ifo,
-            indices_only=True
+            return_data=False,
        )
        data_mask = numpy.zeros(data[ifo]['snr'].size, dtype=bool)
        data_mask[idx] = True
diff --git a/bin/plotting/pycbc_plot_singles_timefreq b/bin/plotting/pycbc_plot_singles_timefreq
index 20f480240f9..e8d12e9575a 100644
--- a/bin/plotting/pycbc_plot_singles_timefreq
+++ b/bin/plotting/pycbc_plot_singles_timefreq
@@ -105,13 +105,17 @@ def rough_filter(snr, chisq, chisq_dof, end_time, tmp_id, tmp_dur):
     return np.logical_and(end_time > opts.gps_start_time,
                           end_time < opts.gps_end_time + tmp_dur)
 
-indices, snr, chisq, chisq_dof, end_time, template_ids, template_duration = \
-    trig_f.select(rough_filter, opts.detector + '/snr',
-                  opts.detector + '/chisq', opts.detector + '/chisq_dof',
-                  opts.detector + '/end_time',
-                  opts.detector + '/template_id',
-                  opts.detector + '/template_duration',
-                  return_indices=True)
+indices, data_tuple = trig_f.select(
+    rough_filter,
+    'snr',
+    'chisq',
+    'chisq_dof',
+    'end_time',
+    'template_id',
+    'template_duration',
+    group=opts.detector
+)
+snr, chisq, chisq_dof, end_time, template_ids, template_duration = data_tuple
 
 if len(indices) > 0:
     if opts.veto_file:
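Every updated call site follows the same convention: `select` now always returns the pair `(indices, data_tuple)`, and `data_tuple` is `None` when `return_data=False` is passed. A minimal sketch of both forms, assuming a merged single-detector trigger file laid out as `H1/end_time`, `H1/snr` (the file name is hypothetical); note that the filter function must accept one argument per requested dataset, in order:

```python
from pycbc.io.hdf import HFile

with HFile('H1-TRIGGERS.hdf', 'r') as f:  # hypothetical file name
    # Indices only: the second element of the returned pair is None
    idx, _ = f.select(
        lambda snr: snr > 6,
        'snr',
        group='H1',
        return_data=False,
    )

    # Indices plus data: one array per requested dataset, in arg order.
    # The filter function receives all requested datasets as arguments.
    idx, (end_time, snr) = f.select(
        lambda end_time, snr: snr > 6,
        'end_time',
        'snr',
        group='H1',
    )
```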
diff --git a/bin/plotting/pycbc_plot_singles_vs_params b/bin/plotting/pycbc_plot_singles_vs_params
index e3ab1ca3852..594b0ac0399 100644
--- a/bin/plotting/pycbc_plot_singles_vs_params
+++ b/bin/plotting/pycbc_plot_singles_vs_params
@@ -90,11 +90,11 @@ if opts.min_snr:
     n_triggers_orig = trig_file[f'{opts.detector}/snr'].size
     logging.info("Trigger file has %d triggers", n_triggers_orig)
     logging.info('Generating trigger mask (on SNR)')
-    idx = trig_file.select(
+    idx, _ = trig_file.select(
         lambda snr: snr >= opts.min_snr,
         'snr',
         group=opts.detector,
-        indices_only=True,
+        return_data=False,
     )
     logging.info('%d triggers after snr mask', idx.size)
     data_mask = np.zeros(n_triggers_orig, dtype=bool)
@@ -139,12 +139,12 @@ if opts.max_y is not None:
 x = x[mask]
 y = y[mask]
 
-title = '%s of %s triggers over %s and %s' % (opts.z_var.title(),
-    opts.detector, opts.x_var.title(), opts.y_var.title())
-fig_caption = ("This plot shows the %s of single detector triggers for the %s "
-               "detector. %s is shown on the colorbar axis against %s and %s "
-               "on the x- and y-axes." % (opts.z_var, opts.detector,
-               opts.z_var.title(), opts.x_var, opts.y_var))
+title = f'{opts.z_var.title()} of {opts.detector} triggers ' + \
+    f'over {opts.x_var.title()} and {opts.y_var.title()}'
+fig_caption = f"This plot shows the {opts.z_var} of single detector " + \
+    f"triggers for the {opts.detector} detector. " + \
+    f"{opts.z_var.title()} is shown on the colorbar axis " + \
+    f"against {opts.x_var} and {opts.y_var} on the x- and y-axes."
 
 if not any(mask):
     # All triggers removed - make a blank plot which says so:
diff --git a/examples/gw150914/PyCBCInspiral.ipynb b/examples/gw150914/PyCBCInspiral.ipynb
index 1b58916781f..4ee1e964e70 100644
--- a/examples/gw150914/PyCBCInspiral.ipynb
+++ b/examples/gw150914/PyCBCInspiral.ipynb
@@ -909,14 +909,14 @@
    "source": [
     "h1_triggers = pycbc.io.hdf.SingleDetTriggers(\n",
     "    'H1-INSPIRAL_FULL_DATA_JOB0-1126257771-1837.hdf',\n",
-    "    'H1'\n",
-    "    'bank_file=H1L1-GW150914_BANK-1126051217-3331800.hdf',\n",
+    "    'H1L1-GW150914_BANK-1126051217-3331800.hdf',\n",
+    "    None, None, None, 'H1'\n",
     ")\n",
     "\n",
     "l1_triggers = pycbc.io.hdf.SingleDetTriggers(\n",
     "    'L1-INSPIRAL_FULL_DATA_JOB0-1126258302-1591.hdf',\n",
-    "    'L1'\n",
-    "    'bank_file=H1L1-GW150914_BANK-1126051217-3331800.hdf',\n",
+    "    'H1L1-GW150914_BANK-1126051217-3331800.hdf',\n",
+    "    None, None, None, 'L1'\n",
     ")\n",
     "\n",
     "imax = np.argmax(h1_triggers.snr)\n",
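The notebook passes the constructor arguments positionally. If the installed PyCBC matches the `__init__` signature shown in the pycbc/io/hdf.py hunks below (trig_file first, detector second, everything else keyword), the keyword spelling is a less error-prone equivalent; this is a sketch under that assumption, not the form the notebook requires:

```python
import pycbc.io.hdf

# Keyword form, assuming the signature in the pycbc/io/hdf.py hunk below;
# positional None placeholders simply become their keyword defaults.
h1_triggers = pycbc.io.hdf.SingleDetTriggers(
    'H1-INSPIRAL_FULL_DATA_JOB0-1126257771-1837.hdf',
    'H1',
    bank_file='H1L1-GW150914_BANK-1126051217-3331800.hdf',
)
```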
diff --git a/pycbc/io/hdf.py b/pycbc/io/hdf.py
index 44ee1d6398a..08b2c8d7c11 100644
--- a/pycbc/io/hdf.py
+++ b/pycbc/io/hdf.py
@@ -29,7 +29,8 @@ class HFile(h5py.File):
     """ Low level extensions to the capabilities of reading an hdf5 File
     """
-    def select(self, fcn, *args, **kwds):
+    def select(self, fcn, *args, chunksize=10**6, derived=None, group='',
+               return_data=True, premask=None):
         """ Return arrays from an hdf5 file that satisfy the given function
 
         Parameters
         ----------
@@ -42,13 +43,14 @@
             A variable number of strings that are keys into the hdf5. These
             must refer to arrays of equal length.
 
-        chunksize : {1e6, int}, optional
+        chunksize : {10**6, int}, optional
             Number of elements to read and process at a time.
 
         derived : dictionary
-            Dictionary keyed on function, values are the list of required
-            datasets. If giving dataset outputs, these will be added at the
-            end. The function must take in a dictionary keyed on dataset names.
+            Dictionary keyed on argument name (must be given in args); each
+            value is a tuple of the function used to compute it and the
+            required datasets. The function must take in a dictionary keyed
+            on those dataset names.
 
         group : string, optional
             The group within the h5py file containing the datasets, e.g. in
@@ -56,21 +58,23 @@
             can be included in the args manually, but is required in the case
             of derived functions, e.g. newsnr.
 
-        return_indices : bool, optional
-            If True, also return the indices of elements passing the function.
-
-        indices_only : bool, optional
-            If True, only return the indices of elements passing the function.
+        return_data : bool, optional, default True
+            If True, return the data for elements passing the function.
 
         premask : array of boolean values, optional
            The pre-mask to apply to the triggers at read-in.
 
         Returns
         -------
-        values : np.ndarrays
-            A variable number of arrays depending on the number of keys into
-            the hdf5 file that are given. If return_indices is True, the first
-            element is an array of indices of elements passing the function.
+        indices : np.ndarray
+            An array of indices of elements passing the function.
+
+        return_tuple : tuple of np.ndarrays
+            A variable number of arrays depending on the number of
+            args provided.
+            If return_data is True, the arrays are the values of
+            each arg.
+            If return_data is False, this is None.
 
         >>> f = HFile(filename)
-        >>> snr = f.select(lambda snr: snr > 6, 'H1/snr')
+        >>> indices, (snr,) = f.select(lambda snr: snr > 6, 'H1/snr')
@@ -78,9 +82,9 @@
         """
         # Required datasets are the arguments requested and datasets given
         # for any derived functions
-        derived = kwds.get('derived', {})
-        dsets = list(args)
-        for rqd_list in derived.values():
+        derived = derived if derived is not None else {}
+        dsets = [a for a in list(args) if a not in derived]
+        for _, rqd_list in derived.values():
             dsets += rqd_list
 
         # remove any duplicates from req_dsets
@@ -90,7 +94,6 @@
         # check they can all be used together
         refs = {}
         size = None
-        group = kwds.get('group', '')
         for ds in dsets:
             refs[ds] = self[group + '/' + ds]
             if (size is not None) and (refs[ds].size != size):
@@ -99,13 +102,11 @@
                                    f"previous input datasets ({size}).")
             size = refs[ds].size
 
-        # To conserve memory read the array in chunks
-        chunksize = kwds.get('chunksize', int(1e6))
-
-        if 'premask' not in kwds or kwds.get('premask') is None:
+        # Apply any pre-masks
+        if premask is None:
             mask = np.ones(size, dtype=bool)
         else:
-            mask = kwds['premask']
+            mask = premask
 
         if not mask.dtype == bool:
             # mask is an array of indices rather than booleans,
@@ -118,19 +119,13 @@
             raise RuntimeError(f"Using premask of size {mask.size} which "
                                f"does not match the input datasets ({size}).")
 
-        # This will be the outputs:
-        return_indices = kwds.get('return_indices', False)
-        indices_only = kwds.get('indices_only', False)
-
-        # Arguments being returned:
-        # The name doesn't matter, so key on the function of
-        # derived datasets
-        ret_args = args + tuple(derived.keys())
+        # datasets being returned (possibly)
         data = {}
         indices = np.array([], dtype=np.uint64)
-        for arg in ret_args:
+        for arg in args:
             data[arg] = []
 
+        # Loop through the chunks:
         i = 0
         while i < size:
             r = i + chunksize if i + chunksize < size else size
@@ -143,32 +138,36 @@
             # Read each chunk's worth of data
             partial_data = {arg: refs[arg][i:r][mask[i:r]]
                             for arg in dsets}
-            partial = [partial_data[a] for a in args]
-            partial += [func(partial_data) for func in derived.keys()]
+            partial = []
+            for a in args:
+                if a in derived.keys():
+                    # If this is a derived dataset, calculate it
+                    derived_fcn = derived[a][0]
+                    partial += [derived_fcn(partial_data)]
+                else:
+                    # otherwise, just read from the file
+                    partial += [partial_data[a]]
 
+            # Find where it passes the function
             keep = fcn(*partial)
-            if return_indices or indices_only:
-                indices = np.concatenate([indices, np.flatnonzero(keep) + i])
-            # Store only the results that pass the function
-            for arg, part in zip(ret_args, partial):
-                if not indices_only:
+            # Keep the indices which pass the function:
+            indices = np.concatenate([indices, np.flatnonzero(keep) + i])
+
+            if return_data:
+                # Store the dataset results that pass the function
+                for arg, part in zip(args, partial):
                     data[arg].append(part[keep])
 
             i += chunksize
 
-        return_tuple = tuple()
-        # Combine the partial results into full arrays
-        if indices_only or return_indices:
-            return_tuple += (indices.astype(np.uint64),)
-        if not indices_only:
-            return_tuple += tuple(np.concatenate(data[arg])
-                                  for arg in ret_args)
-
-        if len(return_tuple) == 1:
-            return return_tuple[0]
+        if return_data:
+            return_tuple = tuple(np.concatenate(data[arg])
+                                 for arg in args)
         else:
-            return return_tuple
+            return_tuple = None
+
+        return indices.astype(np.uint64), return_tuple
 
 
 class DictArray(object):
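Under the new layout, a `derived` entry is keyed on the argument name, maps to `(function, required_datasets)`, and that name must also appear in `args`; the function receives a dict keyed on those dataset names and its output is returned in arg order like any other dataset. A sketch with a hypothetical derived argument (the file name and the `rchisq` helper are illustrative, not part of this patch):

```python
import numpy as np
from pycbc.io.hdf import HFile

def rchisq(d):
    # d is a dictionary keyed on the required dataset names
    return d['chisq'] / (2 * d['chisq_dof'] - 2)

with HFile('H1-TRIGGERS.hdf', 'r') as f:  # hypothetical file name
    idx, (snr, rc) = f.select(
        lambda snr, rc: (snr > 6) & (rc < 2),
        'snr',
        'rchisq',  # derived argument, not a dataset in the file
        derived={'rchisq': (rchisq, ['chisq', 'chisq_dof'])},
        group='H1',
    )
```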
@@ -471,7 +470,7 @@ class SingleDetTriggers(object):
     """
     def __init__(self, trig_file, detector, bank_file=None, veto_file=None,
                  segment_name=None, premask=None, filter_rank=None,
-                 filter_threshold=None, chunksize=int(1e6), filter_func=None):
+                 filter_threshold=None, chunksize=10**6, filter_func=None):
         """
         Create a SingleDetTriggers instance
 
@@ -503,7 +502,7 @@
         filter_threshold: float, required if filter_rank is used
             Threshold to filter the ranking values
 
-        chunksize : int , default 1e6
+        chunksize : int, default 10**6
             Size of chunks to read in for the filter_rank / threshold.
         """
         logging.info('Loading triggers')
@@ -527,11 +526,13 @@
             assert filter_threshold is not None
             logging.info("Applying threshold of %.3f on %s",
                          filter_threshold, filter_rank)
-            idx = self.trigs_f.select(
+            fcn_dsets = (ranking.sngls_ranking_function_dict[filter_rank],
+                         ranking.required_datasets[filter_rank])
+            idx, _ = self.trigs_f.select(
                 lambda rank: rank > filter_threshold,
-                derived={ranking.sngls_ranking_function_dict[filter_rank]:
-                         ranking.required_datasets[filter_rank]},
-                indices_only=True,
+                filter_rank,
+                derived={filter_rank: fcn_dsets},
+                return_data=False,
                 premask=self.mask,
                 group=detector,
                 chunksize=chunksize,
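The filter_rank path above is the main consumer of the new `derived` format: the ranking name is passed both as the single requested arg and as the `derived` key. A standalone sketch of the equivalent direct call, assuming 'new_snr' is among the keys of `ranking.sngls_ranking_function_dict` and `ranking.required_datasets` (the file name is hypothetical):

```python
from pycbc.events import ranking
from pycbc.io.hdf import HFile

rank = 'new_snr'  # assumed key of the ranking dictionaries
fcn_dsets = (ranking.sngls_ranking_function_dict[rank],
             ranking.required_datasets[rank])

with HFile('H1-TRIGGERS.hdf', 'r') as f:  # hypothetical file name
    idx, _ = f.select(
        lambda r: r > 8,
        rank,                       # derived argument name
        derived={rank: fcn_dsets},
        return_data=False,
        group='H1',
    )
```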