From d06e9fbec002ebf36b4203383e5320b266a9ee4e Mon Sep 17 00:00:00 2001 From: GarethCabournDavies Date: Wed, 13 Nov 2024 07:59:12 -0800 Subject: [PATCH] Use inference's parameter labels: they are available and mostly good --- bin/plotting/pycbc_plot_bank_compression | 58 +++++++++---------- .../pycbc_make_bank_compression_workflow | 8 +-- pycbc/workflow/plotting.py | 2 +- 3 files changed, 33 insertions(+), 35 deletions(-) diff --git a/bin/plotting/pycbc_plot_bank_compression b/bin/plotting/pycbc_plot_bank_compression index ab59d698a25..e0d7d2c380b 100644 --- a/bin/plotting/pycbc_plot_bank_compression +++ b/bin/plotting/pycbc_plot_bank_compression @@ -16,7 +16,7 @@ import sys import pycbc from pycbc.io import HFile from pycbc.results import save_fig_with_metadata -from pycbc.waveform import bank +from pycbc.inference import option_utils import pycbc.tmpltbank as tmpltbank parser = argparse.ArgumentParser() @@ -60,35 +60,35 @@ parser.add_argument( help="Flag to indicate that histogram values should be plotted " "on a log scale" ) -parser.add_argument( - "--comparison-parameter", - default="template_duration:Template duration (s)", - help=( - "Parameter to be compared to in the scatter plot. Supplied " - "as a PARAMETER:NAME pair. PARAMETER should be one of %s, NAME " - "can contain latex symbols. Default template_duration:'Template " - "duration (s)'" - ) % ('{' + ', '.join(tmpltbank.conversion_options) + '}') +default_param = "template_duration" +parser.add_argument("--comparison-parameter", + action=option_utils.ParseParametersArg, + metavar="PARAM[:LABEL]", + help="Plot the scatter plot of compressin factor versus the given " + "parameter. Optionally provide a LABEL for use in the plot. " + "Choose from " + ", ".join(tmpltbank.conversion_options) + ", " + "though some options may not be buildable from bank parameters. " + "If no LABEL is provided, PARAM will used as the LABEL. If LABEL " + "is the same as a parameter in pycbc.waveform.parameters, the label " + "property of that parameter will be used. Default: " + default_param ) args = parser.parse_args() -try: - comparison_parameter, comparison_label = \ - args.comparison_parameter.split(':') -except ValueError: - raise parser.error( - "Incorrect format of --comparison-parameter, got %s. See help" % - args.comparison_parameter - ) - -if comparison_parameter not in tmpltbank.conversion_options: +if args.comparison_parameter is None: + args.comparison_parameter = default_param + args.comparison_parameter_labels = { + default_param: "Template Duration (s)" + } +elif args.comparison_parameter not in tmpltbank.conversion_options: raise parser.error( "--comparison-parameter %s not in conversion options %s, see help" - % (comparison_parameter, ', '.join(tmpltbank.conversion_options)) + % (args.comparison_parameter, ', '.join(tmpltbank.conversion_options)) ) pycbc.init_logging(args.verbose) +comp_label = args.comparison_parameter_labels[args.comparison_parameter] + # Quieten the matplotlib logger plt.set_loglevel("info" if args.verbose else "warning") logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) @@ -121,9 +121,9 @@ for i, bank_fname in enumerate(args.bank_files): logging.debug("Getting approximants") # These are in template_id order, so use hash_order to get them back: approximants += [apx.decode() for apx in bank_f["approximant"][:][hash_order]] - logging.debug("Getting comparison values: %s", comparison_parameter) + logging.debug("Getting comparison values: %s", args.comparison_parameter) comparison_values += list(tmpltbank.get_bank_property( - comparison_parameter, + args.comparison_parameter, bank_f, template_ids=hash_order )) @@ -132,10 +132,6 @@ approximants = np.array(approximants) comparison_values = np.array(comparison_values) compression_factor = np.array(compression_factor) -print(approximants.size) -print(comparison_values.size) -print(compression_factor.size) - # Store the max/min factors, as these are used for setting # histogram / plot limits max_factor = compression_factor.max() @@ -241,7 +237,7 @@ else: axes[0].set_ylabel("Number of Templates") axes[0].set_xlabel("Compression Factor") -axes[1].set_xlabel(comparison_label) +axes[1].set_xlabel(comp_label) axes[1].set_ylabel("Compression Factor") axes[0].legend(loc='upper right') @@ -254,14 +250,16 @@ caption = ( "Plot showing the a histogram of compression factor (left) and a " "scatter plot of compression factor vs %s (%s) (right). " "Legend entries indicate the number of templates per approximant. " -) % (comparison_label, comparison_parameter) +) % (comp_label, args.comparison_parameter) if args.histogram_density: caption += "Density for each histogram is weighted by the number of templates " + +logging.info("Saving figure") save_fig_with_metadata( fig, args.output, - title="Bank compression vs %s" % comparison_label, + title="Bank compression vs %s" % comp_label, caption=caption, cmd=' '.join(sys.argv) ) diff --git a/bin/workflows/pycbc_make_bank_compression_workflow b/bin/workflows/pycbc_make_bank_compression_workflow index 3e635051d7f..86b73e68760 100644 --- a/bin/workflows/pycbc_make_bank_compression_workflow +++ b/bin/workflows/pycbc_make_bank_compression_workflow @@ -139,23 +139,23 @@ splitbank_files = wf.setup_splittable_dax_generated( compressed_files = wf.make_compress_split_banks( workflow, splitbank_files, - out_dir='compressed_bank', + out_dir='compress_bank', tags=None, ) # All the split banks have had the waveforms compressed, so now # join them back together -rejoined_banks = wf.make_combine_split_banks( +combine_banks = wf.make_combine_split_banks( workflow, compressed_files, - out_dir='rejoined_bank', + out_dir='combine_bank', tags=None, ) # Make a plot of the compression factor of the templates plots = wf.make_bank_compression_plots( workflow, - rejoined_banks, + combine_banks, out_dir=rdir.base, tags=None, ) diff --git a/pycbc/workflow/plotting.py b/pycbc/workflow/plotting.py index ee2134c04c2..65e4a0606d5 100644 --- a/pycbc/workflow/plotting.py +++ b/pycbc/workflow/plotting.py @@ -596,7 +596,7 @@ def make_bank_compression_plots(workflow, bank_files, out_dir, tags=None): node.new_output_file_opt( workflow.analysis_time, '.png', - '--output-file' + '--output' ) workflow += node files += node.output_files