Skip to content

Commit

Permalink
Use inference's parameter labels: they are available and mostly good
Browse files Browse the repository at this point in the history
  • Loading branch information
GarethCabournDavies committed Nov 13, 2024
1 parent 643552e commit d06e9fb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 35 deletions.
58 changes: 28 additions & 30 deletions bin/plotting/pycbc_plot_bank_compression
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import sys
import pycbc
from pycbc.io import HFile
from pycbc.results import save_fig_with_metadata
from pycbc.waveform import bank
from pycbc.inference import option_utils
import pycbc.tmpltbank as tmpltbank

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -60,35 +60,35 @@ parser.add_argument(
help="Flag to indicate that histogram values should be plotted "
"on a log scale"
)
parser.add_argument(
"--comparison-parameter",
default="template_duration:Template duration (s)",
help=(
"Parameter to be compared to in the scatter plot. Supplied "
"as a PARAMETER:NAME pair. PARAMETER should be one of %s, NAME "
"can contain latex symbols. Default template_duration:'Template "
"duration (s)'"
) % ('{' + ', '.join(tmpltbank.conversion_options) + '}')
default_param = "template_duration"
parser.add_argument("--comparison-parameter",
action=option_utils.ParseParametersArg,
metavar="PARAM[:LABEL]",
help="Plot the scatter plot of compressin factor versus the given "
"parameter. Optionally provide a LABEL for use in the plot. "
"Choose from " + ", ".join(tmpltbank.conversion_options) + ", "
"though some options may not be buildable from bank parameters. "
"If no LABEL is provided, PARAM will used as the LABEL. If LABEL "
"is the same as a parameter in pycbc.waveform.parameters, the label "
"property of that parameter will be used. Default: " + default_param
)
args = parser.parse_args()

try:
comparison_parameter, comparison_label = \
args.comparison_parameter.split(':')
except ValueError:
raise parser.error(
"Incorrect format of --comparison-parameter, got %s. See help" %
args.comparison_parameter
)

if comparison_parameter not in tmpltbank.conversion_options:
if args.comparison_parameter is None:
args.comparison_parameter = default_param
args.comparison_parameter_labels = {
default_param: "Template Duration (s)"
}
elif args.comparison_parameter not in tmpltbank.conversion_options:
raise parser.error(
"--comparison-parameter %s not in conversion options %s, see help"
% (comparison_parameter, ', '.join(tmpltbank.conversion_options))
% (args.comparison_parameter, ', '.join(tmpltbank.conversion_options))
)

pycbc.init_logging(args.verbose)

comp_label = args.comparison_parameter_labels[args.comparison_parameter]

# Quieten the matplotlib logger
plt.set_loglevel("info" if args.verbose else "warning")
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
Expand Down Expand Up @@ -121,9 +121,9 @@ for i, bank_fname in enumerate(args.bank_files):
logging.debug("Getting approximants")
# These are in template_id order, so use hash_order to get them back:
approximants += [apx.decode() for apx in bank_f["approximant"][:][hash_order]]
logging.debug("Getting comparison values: %s", comparison_parameter)
logging.debug("Getting comparison values: %s", args.comparison_parameter)
comparison_values += list(tmpltbank.get_bank_property(
comparison_parameter,
args.comparison_parameter,
bank_f,
template_ids=hash_order
))
Expand All @@ -132,10 +132,6 @@ approximants = np.array(approximants)
comparison_values = np.array(comparison_values)
compression_factor = np.array(compression_factor)

print(approximants.size)
print(comparison_values.size)
print(compression_factor.size)

# Store the max/min factors, as these are used for setting
# histogram / plot limits
max_factor = compression_factor.max()
Expand Down Expand Up @@ -241,7 +237,7 @@ else:
axes[0].set_ylabel("Number of Templates")
axes[0].set_xlabel("Compression Factor")

axes[1].set_xlabel(comparison_label)
axes[1].set_xlabel(comp_label)
axes[1].set_ylabel("Compression Factor")

axes[0].legend(loc='upper right')
Expand All @@ -254,14 +250,16 @@ caption = (
"Plot showing the a histogram of compression factor (left) and a "
"scatter plot of compression factor vs %s (%s) (right). "
"Legend entries indicate the number of templates per approximant. "
) % (comparison_label, comparison_parameter)
) % (comp_label, args.comparison_parameter)

if args.histogram_density:
caption += "Density for each histogram is weighted by the number of templates "

logging.info("Saving figure")
save_fig_with_metadata(
fig,
args.output,
title="Bank compression vs %s" % comparison_label,
title="Bank compression vs %s" % comp_label,
caption=caption,
cmd=' '.join(sys.argv)
)
Expand Down
8 changes: 4 additions & 4 deletions bin/workflows/pycbc_make_bank_compression_workflow
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,23 @@ splitbank_files = wf.setup_splittable_dax_generated(
compressed_files = wf.make_compress_split_banks(
workflow,
splitbank_files,
out_dir='compressed_bank',
out_dir='compress_bank',
tags=None,
)

# All the split banks have had the waveforms compressed, so now
# join them back together
rejoined_banks = wf.make_combine_split_banks(
combine_banks = wf.make_combine_split_banks(
workflow,
compressed_files,
out_dir='rejoined_bank',
out_dir='combine_bank',
tags=None,
)

# Make a plot of the compression factor of the templates
plots = wf.make_bank_compression_plots(
workflow,
rejoined_banks,
combine_banks,
out_dir=rdir.base,
tags=None,
)
Expand Down
2 changes: 1 addition & 1 deletion pycbc/workflow/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def make_bank_compression_plots(workflow, bank_files, out_dir, tags=None):
node.new_output_file_opt(
workflow.analysis_time,
'.png',
'--output-file'
'--output'
)
workflow += node
files += node.output_files
Expand Down

0 comments on commit d06e9fb

Please sign in to comment.