Commit

feat: update visualiser with Mixtral runs

DriesSmit committed May 27, 2024
1 parent a384f98 commit db3cf3c
Showing 3 changed files with 91 additions and 29 deletions.
6 changes: 2 additions & 4 deletions scripts/visualise_results.py
@@ -143,7 +143,7 @@


# Filter out all runs that were less than 100% complete
runs_table_df = runs_table_df[runs_table_df["eval/percent_complete"] == 100.0]
runs_table_df = runs_table_df[runs_table_df["eval/percent_complete"] >= 100.0]



@@ -170,9 +170,6 @@
if "few_shot" in key:
print(key)




if chart_type == "bar":
raise NotImplementedError("Bar charts are not implemented yet")
# for metric in tqdm(metrics):
@@ -379,6 +376,7 @@
"cosmosqa": [],
"ciar": [],
"gpqa": [],
"chess": [],
}
datasets = results.keys()
for unique_id in runs_table_df.index:
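For context, the results dict above is keyed by dataset name and filled while iterating runs_table_df.index; a minimal sketch of that pattern, where the dataset and score column names are assumptions for illustration only:

# Sketch only: bucket each run's score by dataset (both column names are assumed).
results = {"cosmosqa": [], "ciar": [], "gpqa": [], "chess": []}
for unique_id in runs_table_df.index:
    dataset_name = runs_table_df.loc[unique_id, "config/dataset"]
    if dataset_name in results:
        results[dataset_name].append(runs_table_df.loc[unique_id, "eval/score"])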
99 changes: 79 additions & 20 deletions scripts/visualise_utils.py
@@ -123,6 +123,13 @@ def latexify(fig_width=7, fig_height=5, columns=1):
"truemed.systems.SingleAgentQA": "Single Agent",
"truemed.systems.SelfConsistency": "Self-Consistency",
"truemed.systems.Medprompt": "Medprompt",
"debatellm.systems.ChatEvalDebate": "ChatEval",
"debatellm.systems.EnsembleRefinementDebate": "Ensemble Refinement",
"debatellm.systems.MultiAgentDebateGoogle": "Society of Mind",
"debatellm.systems.MultiAgentDebateTsinghua": "Multi-Persona",
"debatellm.systems.SingleAgentQA": "Single Agent",
"debatellm.systems.SelfConsistency": "Self-Consistency",
"debatellm.systems.Medprompt": "Medprompt",
}

marker_dict = {
@@ -145,6 +152,7 @@ def latexify(fig_width=7, fig_height=5, columns=1):
"ciar": "CIAR",
"cosmosqa": "CosmosQA",
"gpqa": "GPQA",
"chess": "Chess",
}

muted_color_dict = {
@@ -510,9 +518,16 @@ def get_paper_dataset_ranges(dataset):
# ]

# GPT-4 results
# run_range = [
# [4091, 4199],
# ]

# Mixtral results
run_range = [
[4091, 4199],
[4555, 4620],
[4627, 5000],
]

elif dataset == "pubmedqa":
run_range = [
[3832, 3885],
@@ -528,6 +543,10 @@
[3904, 4058],
[4116, 4163],
]
elif dataset == "chess":
run_range = [
[4424, 4599],
]
return run_range
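Together with get_dataset_runs further down, these hard-coded ranges are meant to be iterated and combined; a usage sketch, assuming the dataset keys used above:

# Sketch: fetch every paper run range for one dataset and stack the tables.
import pandas as pd

tables = [get_dataset_runs(r, "chess") for r in get_paper_dataset_ranges("chess")]
runs_table_df = pd.concat(tables, ignore_index=True)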


@@ -537,26 +556,30 @@ def filter_results_for_paper(runs_table_df, dataset):
dataset = dataset.lower()

# Filter out GPT-3.5
runs_table_df = runs_table_df[
runs_table_df["config/system/agents/Agent_0/engine"] != "chat-bison@001"
]
# runs_table_df = runs_table_df[
# runs_table_df["config/system/agents/Agent_0/engine"] != "chat-bison@001"
# ]

# Filter out PaLM
runs_table_df = runs_table_df[
runs_table_df["config/system/agents/Agent_0/engine"] != "chat-bison@001"
]


# Filter out GPT-4
runs_table_df = runs_table_df[
runs_table_df["config/system/agents/Agent_0/engine"] != "gpt-4"
]

# Filter out Mixtral
# runs_table_df = runs_table_df[
# runs_table_df["config/system/agents/Agent_0/engine"] != "gpt-4"
# runs_table_df["config/system/agents/Agent_0/engine"] != "mixtral-8x7b-instruct"
# ]

# Filter out Multi-Persona agreement intensities greater than 0
# Filter out runs where 'ai' is greater than -1, ignoring NaN values

runs_table_df = runs_table_df[
(runs_table_df["config/system/agreement_intensity"] == 9)
(runs_table_df["config/system/agreement_intensity"] == 9) # Use -1 or 9
| (runs_table_df["config/system/_target_"] != "truemed.systems.MultiAgentDebateTsinghua")
]
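The mask above leans on pandas comparison semantics: NaN never compares equal to anything, so a Multi-Persona run with a missing agreement intensity fails both sides of the | and is dropped, while other systems always pass the second condition. A toy check with hypothetical column values:

import pandas as pd

toy = pd.DataFrame({
    "ai": [9.0, float("nan"), 0.0],
    "sys": ["tsinghua", "tsinghua", "other"],
})
mask = (toy["ai"] == 9) | (toy["sys"] != "tsinghua")
print(mask.tolist())  # [True, False, True]: the NaN row is dropped, 'other' is always kept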

@@ -616,6 +639,9 @@ def get_dataset_runs(run_range: List[int], dataset: str = None, engine: str = No
def fetch_runs_in_batches(run_range, batch_size=1000):
    all_dataframes = []
    start, end = run_range

    # Include the final result as well.
    end = end + 1
    for batch_start in range(start, end, batch_size):
        batch_end = min(batch_start + batch_size, end)
        batch_ids = [f"TRUEM-{run_id}" for run_id in range(batch_start, batch_end)]
@@ -637,6 +663,9 @@ def fetch_runs_in_batches(run_range, batch_size=1000):
# runs_table_df = runs_table_df[runs_table_df["eval/percent_complete"] >= 40.0]
runs_table_df = runs_table_df[runs_table_df["eval/percent_complete"] >= 100.0]




# At least 10 evaluations
runs_table_df["config/max_eval_count"] = pd.to_numeric(
runs_table_df["config/max_eval_count"], errors="coerce"
@@ -648,7 +677,6 @@ def fetch_runs_in_batches(run_range, batch_size=1000):

runs_table_df = filter_results_for_paper(runs_table_df, dataset)


if dataset is not None:
# Filter out all runs that are not from the current dataset
runs_table_df = runs_table_df[
@@ -667,15 +695,6 @@
[key for key in runs_table_df.keys() if "eval/" in key or "config/" in key]
]

name_mapping = {
"truemed.systems.ChatEvalDebate": "ChatEval",
"truemed.systems.EnsembleRefinementDebate": "Ensemble Refinement",
"truemed.systems.MultiAgentDebateGoogle": "Society of Mind",
"truemed.systems.MultiAgentDebateTsinghua": "Multi-Persona",
"truemed.systems.SingleAgentQA": "Single Agent",
"truemed.systems.Medprompt": "Medprompt",
}

runs_table_df = runs_table_df[runs_table_df["config/system/_target_"].notna()]
runs_table_df["system_name"] = runs_table_df["config/system/_target_"].apply(
lambda x: name_mapping.get(x, x)
@@ -752,8 +771,32 @@ def get_scatter_plot(metrics, runs_table_df, legend=True, dataset="", save_path=
hue_order=system_order,
data=runs_table_df_plot,
)
plt.xscale("log")
plt.xlabel(f"{metrics} (Log Scale)")

# Add a point that indicates what we achieved when improving Tsinghua

# # Add a point for a specific protocol
# point_value = 0.665 # Replace with the actual point value

# # Get x-coordinate of the protocol
# protocol_x_coord = 11.5

# plt.scatter(
# x=[protocol_x_coord], # x-coordinate of the box
# y=[point_value], # y-coordinate (point value)
# color="red",
# s=100, # Size of the point
# label='Improved Multi-Persona', # Label for legend
# zorder=5, # Make sure point appears on top
# marker='x' # X-shaped marker
# )

plt.xlabel("")
plt.ylabel("Accuracy Accuracy (out of 1.0)")
# plt.title("Total Accuracy by System")


# plt.xscale("log")
plt.xlabel(f"{metrics}")
plt.ylabel(f"{dataset} Accuracy (out of 1.0)")
if legend:
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
@@ -945,7 +988,8 @@ def generate_box_chart(
if system in all_sorted_systems:
    sorted_systems.append(system)

# Create the boxplot
print("sorted_systems:", sorted_systems)

# Create the boxplot
sns.boxplot(
x=runs_table_df["config/system/_target_"],
@@ -1013,8 +1057,22 @@ def generate_box_chart(
plt.savefig(save_path, bbox_inches="tight")
plt.show()

# Calculate the best and median score for each system
best_scores = runs_table_df.groupby("config/system/_target_")[y_metric].max()
median_scores = runs_table_df.groupby("config/system/_target_")[y_metric].median()

# Print the best and median scores
print("Best scores:")
print(best_scores)
print("\nMedian scores:")
print(median_scores)
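Since the groupby keys are raw _target_ paths, the printout could optionally reuse the name_mapping defined at the top of this file to report human-readable system names; a small sketch:

# Optional sketch: print the same scores under display names.
print(best_scores.rename(index=lambda t: name_mapping.get(t, t)))
print(median_scores.rename(index=lambda t: name_mapping.get(t, t)))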


def create_num_rounds_column(df):
    # Initialize 'num_rounds' with NaN if it doesn't exist
    if "config/system/num_rounds" not in df.columns and "config/system/max_num_rounds" in df.columns:
        df["num_rounds"] = pd.NA

    # Use 'config/system/num_rounds' if it exists and is not NaN
    if "config/system/num_rounds" in df.columns:
        df["num_rounds"] = df["config/system/num_rounds"]
@@ -1023,6 +1081,7 @@ def create_num_rounds_column(df):
if "config/system/max_num_rounds" in df.columns:
df["num_rounds"].fillna(df["config/system/max_num_rounds"], inplace=True)


# Use 1 if both columns are NaN or don't exist
df["num_rounds"].fillna(1, inplace=True)

15 changes: 10 additions & 5 deletions scripts/vizualize_results_scatter.py
@@ -27,8 +27,9 @@
# limitations under the License.

from visualise_utils import get_dataset_runs, get_paper_dataset_ranges, get_scatter_plot
import pandas as pd

DATASET = "MedQA" # MedQA, USMLE, PubMedQA, MMLU, CosmosQA, Ciar, GPQA,
DATASET = "MedQA" # MedQA, USMLE, PubMedQA, MMLU, CosmosQA, Ciar, GPQA, Chess
# RUN_RANGE = [3904, 4058]

# Hardcode the ranges based on the names
@@ -44,10 +45,14 @@
SAVE_PATH = f"./data/charts/{DATASET}_{METRIC}_scatter_plots.pdf"  # None if you don't want to save

LEGEND = True # whether to show the legend or not

assert len(RUN_RANGE) == 1, "Please provide the correct range for the dataset"

run_table = get_dataset_runs(RUN_RANGE[0], DATASET.lower())

run_table = None

for run_range in RUN_RANGE:
    if run_table is None:
        run_table = get_dataset_runs(run_range, DATASET.lower())
    else:
        run_table = pd.concat(
            [run_table, get_dataset_runs(run_range, DATASET.lower())],
            ignore_index=True,
        )

print(f"Number of runs: {len(run_table)}")

