From dce6e91597b60b894d9c32a0571fac4fdc7fc34b Mon Sep 17 00:00:00 2001 From: Alexander Harvey Nitz Date: Tue, 11 Aug 2020 20:47:38 +0200 Subject: [PATCH] extend pycbc_create_injections to work for search workflows --- bin/pycbc_create_injections | 61 ++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/bin/pycbc_create_injections b/bin/pycbc_create_injections index 496373f7195..fa84b3743e6 100644 --- a/bin/pycbc_create_injections +++ b/bin/pycbc_create_injections @@ -135,7 +135,7 @@ parser.add_argument('--ninjections', type=int, parser.add_argument('--gps-start-time', type=int, help="Alternative to " "ninjections argument. Injections will be distributed " "in time by taking the cumulative sum of the tc values. " - "The number will be chosen to fill the chosen time range. " + "The number will be chosen to fill the chosen time range. ") parser.add_argument('--gps-end-time', type=int, help="Alternative to " "ninjections argument. Injections will be distributed " "in time by taking the cumulative sum of the tc values. 
" @@ -169,7 +169,7 @@ pycbc.init_logging(opts.verbose) if os.path.exists(opts.output_file) and not opts.force: raise OSError("output-file already exists; use --force if you wish to " "overwrite it.") - + if opts.ninjections and (opts.gps_start_time or opts.gps_end_time): raise ValueError("Cannot provide both ninjections and start/end time.") @@ -200,23 +200,50 @@ dists = distributions.read_distributions_from_config(cp, opts.dist_section) randomsampler = JointDistribution(variable_params, *dists, **{"constraints" : constraints}) -logging.info("Drawing samples") -samples = randomsampler.rvs(size=opts.ninjections) - -if waveform_transforms is not None: - logging.info("Transforming to waveform transform parameters") - for t in waveform_transforms: - if not set(t.inputs).isdisjoint(set(static_params.keys())): - for item in list((set(t.inputs) & set(static_params.keys()) - - set(samples.fieldnames))): - samples = samples.add_fields([numpy.repeat(static_params[item], - opts.ninjections).astype(float)], - [item]) - samples = transforms.apply_transforms(samples, waveform_transforms) - write_args = [arg for arg in samples.fieldnames - if arg not in static_params.keys()] +if opts.ninjections: + draw_size = opts.ninjections +else: + draw_size = 4000 # Just a default so it's not super slow drawing large sets + old_samples = None + +while True: + logging.info("Drawing samples") + samples = randomsampler.rvs(size=draw_size) + + if waveform_transforms is not None: + logging.info("Transforming to waveform transform parameters") + for t in waveform_transforms: + if not set(t.inputs).isdisjoint(set(static_params.keys())): + for item in list((set(t.inputs) & set(static_params.keys()) + - set(samples.fieldnames))): + samples = samples.add_fields([numpy.repeat(static_params[item], + draw_size).astype(float)], + [item]) + samples = transforms.apply_transforms(samples, waveform_transforms) + + # We are drawing until we hit a time so check if we've reached it + if opts.gps_start_time: + if 
old_samples is not None: + samples = numpy.concatenate([old_samples, samples]) + old_samples = samples + + # We have enough to cover our time span so we can fix up the times + # and stop here + if samples['tc'].sum() > opts.gps_end_time - opts.gps_start_time: + samples = numpy.sort(samples, order='tc') + samples['tc'] = opts.gps_start_time + samples['tc'].cumsum() + samples = samples[samples['tc'] < opts.gps_end_time] + logging.info('Total Injections: %s', len(samples)) + break + + # We got as many samples as we needed so we can stop + if opts.ninjections and len(samples) >= opts.ninjections: + break # write results logging.info("Writing results") +write_args = [arg for arg in samples.fieldnames + if arg not in static_params.keys()] + InjectionSet.write(opts.output_file, samples, write_args, static_params, cmd=" ".join(sys.argv))