diff --git a/package/scripts/prmon_plot.py b/package/scripts/prmon_plot.py index 6d2eb86..5baa1f5 100755 --- a/package/scripts/prmon_plot.py +++ b/package/scripts/prmon_plot.py @@ -121,6 +121,64 @@ def get_multiplier(label, unit): return MULTIPLIERS[ALLOWEDUNITS[label][0].upper()] / MULTIPLIERS[unit] +# Function for checking the input file exists +def check_input_file(file): + if not os.path.exists(file): + print(f"ERROR:: Input file {file} does not exist") + sys.exit(-1) + + +# Function for loading the data +def load_data(file): + data = pd.read_csv(file, sep="\t") + data["Time"] = pd.to_datetime(data["Time"], unit="s") + return data + + +# Function for checking whether the variables are in the data +def check_variables(data, var, ylist): + if var not in list(data): + print(f"ERROR:: Variable {var} is not available in one of the data sets") + sys.exit(-1) + for carg in ylist: + if carg not in list(data): + print(f"ERROR:: Variable {carg} is not available in one of the data sets") + + +# This function creates the final data set of y-values +def make_list(ylist, data, args, xmult, ymult, xlabel): + ydlist = [] + for carg in ylist: + if args.diff: + num = np.array(data[carg].diff()) * ymult + denom = np.array(data[xlabel].diff()) * xmult + ratio = np.where(denom != 0, num / denom, np.nan) + ydlist.append(ratio) + else: + ydlist.append(np.array(data[carg]) * ymult) + return ydlist + + +# Graph plotting functions +def draw_stacked_graph(xdata, ydlist, ylist): + ydata = np.vstack(ydlist) + plt.stackplot( + xdata, ydata, lw=2, labels=[LEGENDNAMES[val] for val in ylist], alpha=0.6 + ) + + +def draw_line_graph(xdata, ydlist, ylist, sty, inputs, count): + # This is a list of the matplotlib default colours + colours = plt.rcParams["axes.prop_cycle"].by_key()["color"] + for cidx, cdata in enumerate(ydlist): + if len(inputs) == 1: + lbl = LEGENDNAMES[ylist[cidx]] + plt.plot(xdata, cdata, lw=2, label=lbl, color=colours[cidx], linestyle=sty) + else: + lbl = f"{LEGENDNAMES[ylist[cidx]]} ({inputs[count]})" + plt.plot(xdata, cdata, lw=2, label=lbl, color=colours[cidx], linestyle=sty) + + def main(): """prmon plotting main function""" @@ -134,7 +192,8 @@ def main(): "--input", type=str, default="prmon.txt", - help="PrMon TXT output that will be used as input", + help="PrMon TXT output(s) that will be used as input(s)" + " (comma separated list is accepted)", ) parser.add_argument( "--output", @@ -160,7 +219,7 @@ def main(): type=str, default=default_yvar, help="name(s) of the variable(s) to be plotted in the y-axis" - " (comma seperated list is accepted)", + " (comma separated list is accepted)", ) parser.add_argument( "--yunit", @@ -193,24 +252,14 @@ def main(): parser.set_defaults(diff=False) args = parser.parse_args() - # Check the input file exists - if not os.path.exists(args.input): - print(f"ERROR:: Input file {args.input} does not exists") - sys.exit(-1) - - # Load the data - data = pd.read_csv(args.input, sep="\t") - data["Time"] = pd.to_datetime(data["Time"], unit="s") - - # Check the variables are in data - if args.xvar not in list(data): - print(f"ERROR:: Variable {args.xvar} is not available in data") - sys.exit(-1) + inputs = args.input.split(",") ylist = args.yvar.split(",") - for carg in ylist: - if carg not in list(data): - print(f"ERROR:: Variable {carg} is not available in data") - sys.exit(-1) + data = [] + + for i in range(len(inputs)): + check_input_file(inputs[i]) + data.append(load_data(inputs[i])) + check_variables(data[i], args.xvar, ylist) # Check the consistency of variables and units # If they don't match, reset the units to defaults @@ -257,24 +306,31 @@ def main(): # Here comes the figure and data extraction fig, ax1 = plt.subplots() - xdata = np.array(data[xlabel]) * xmultiplier + + xdata = [] ydlist = [] - for carg in ylist: - if args.diff: - num = np.array(data[carg].diff()) * ymultiplier - denom = np.array(data[xlabel].diff()) * xmultiplier - ratio = np.where(denom != 0, num / denom, np.nan) - ydlist.append(ratio) - else: - ydlist.append(np.array(data[carg]) * ymultiplier) + + for i in range(len(data)): + xdata.append(np.array(data[i][xlabel] * xmultiplier)) + ydlist.append(make_list(ylist, data[i], args, xmultiplier, ymultiplier, xlabel)) + + # Plot the graphs + line_styles = list(mpl.lines.lineStyles.keys()) + if args.stacked: - ydata = np.vstack(ydlist) - plt.stackplot( - xdata, ydata, lw=2, labels=[LEGENDNAMES[val] for val in ylist], alpha=0.6 - ) + if len(inputs) == 1: + for i in range(len(xdata)): + draw_stacked_graph(xdata[i], ydlist[i], ylist) + else: + print("ERROR:: Stacked graphs are not supported for more than one data set") + sys.exit(-1) else: - for cidx, cdata in enumerate(ydlist): - plt.plot(xdata, cdata, lw=2, label=LEGENDNAMES[ylist[cidx]]) + for i in range(len(xdata)): + draw_line_graph( + xdata[i], ydlist[i], ylist, line_styles[i % len(line_styles)], inputs, i + ) + + # Create the key plt.legend(loc=0) if "Time" in xlabel: formatter = mpl.dates.DateFormatter("%H:%M:%S") @@ -290,6 +346,7 @@ def main(): else: fylabel = get_axis_label(ylist[0]) fyunit = args.yunit + plt.title("Plot of {} vs {}".format(fxlabel, fylabel), y=1.05) plt.xlabel((fxlabel + " [" + fxunit + "]") if fxunit != "1" else fxlabel) plt.ylabel((fylabel + " [" + fyunit + "]") if fyunit != "1" else fylabel)