Skip to content

Commit

Permalink
extend pycbc_create_injections to work for search workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
ahnitz committed Aug 26, 2020
1 parent aff4493 commit dce6e91
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions bin/pycbc_create_injections
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ parser.add_argument('--ninjections', type=int,
parser.add_argument('--gps-start-time', type=int, help="Alternative to "
"ninjections argument. Injections will be distributed "
"in time by taking the cumulative sum of the tc values. "
"The number will be chosen to fill the chosen time range. "
"The number will be chosen to fill the chosen time range. ")
parser.add_argument('--gps-end-time', type=int, help="Alternative to "
"ninjections argument. Injections will be distributed "
"in time by taking the cumulative sum of the tc values. "
Expand Down Expand Up @@ -169,7 +169,7 @@ pycbc.init_logging(opts.verbose)
if os.path.exists(opts.output_file) and not opts.force:
raise OSError("output-file already exists; use --force if you wish to "
"overwrite it.")

if opts.ninjections and (opts.gps_start_time or opts.gps_end_time):
raise ValueError("Cannot provide both ninjections and start/end time.")

Expand Down Expand Up @@ -200,23 +200,50 @@ dists = distributions.read_distributions_from_config(cp, opts.dist_section)
randomsampler = JointDistribution(variable_params, *dists,
**{"constraints" : constraints})

logging.info("Drawing samples")
samples = randomsampler.rvs(size=opts.ninjections)

if waveform_transforms is not None:
logging.info("Transforming to waveform transform parameters")
for t in waveform_transforms:
if not set(t.inputs).isdisjoint(set(static_params.keys())):
for item in list((set(t.inputs) & set(static_params.keys())
- set(samples.fieldnames))):
samples = samples.add_fields([numpy.repeat(static_params[item],
opts.ninjections).astype(float)],
[item])
samples = transforms.apply_transforms(samples, waveform_transforms)
write_args = [arg for arg in samples.fieldnames
if arg not in static_params.keys()]
if opts.ninjections:
draw_size = opts.ninjections
else:
draw_size = 4000 # Just a default so it's not super slow drawing large sets
old_samples = None

while True:
logging.info("Drawing samples")
samples = randomsampler.rvs(size=draw_size)

if waveform_transforms is not None:
logging.info("Transforming to waveform transform parameters")
for t in waveform_transforms:
if not set(t.inputs).isdisjoint(set(static_params.keys())):
for item in list((set(t.inputs) & set(static_params.keys())
- set(samples.fieldnames))):
samples = samples.add_fields([numpy.repeat(static_params[item],
draw_size).astype(float)],
[item])
samples = transforms.apply_transforms(samples, waveform_transforms)

# We are drawing until we hit a time so check if we've reached it
if opts.gps_start_time:
if old_samples is not None:
samples = numpy.concatenate([old_samples, samples])
old_samples = samples

# We have enough to cover out time span so we can fixup the times
# and stop here
if samples['tc'].sum() > opts.gps_end_time - opts.gps_start_time:
samples = numpy.sort(samples, order='tc')
samples['tc'] = opts.gps_start_time + samples['tc'].cumsum()
samples = samples[samples['tc'] < opts.gps_end_time]
logging.info('Total Injections: %s', len(samples))
break

# We got as many samples as we needed so we can stop
if opts.ninjections and len(samples) >= opts.ninjections:
break

# write results
logging.info("Writing results")
write_args = [arg for arg in samples.fieldnames
if arg not in static_params.keys()]

InjectionSet.write(opts.output_file, samples, write_args, static_params,
cmd=" ".join(sys.argv))

0 comments on commit dce6e91

Please sign in to comment.