Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add class label option to write metric report to improve readability … #7249

Merged
10 changes: 9 additions & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def write_metrics_reports(
summary_ops: str | Sequence[str] | None,
deli: str = ",",
output_type: str = "csv",
class_labels: list[str] | None = None,
) -> None:
"""
Utility function to write the metrics into files, contains 3 parts:
Expand Down Expand Up @@ -94,6 +95,8 @@ class mean median max 5percentile 95percentile notnans
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
class_labels: list of class names used to name the classes in the output report, if None,
"class0", ..., "classn" are used, default to None.
"""
if output_type.lower() != "csv":
Expand All @@ -118,7 +121,12 @@ class mean median max 5percentile 95percentile notnans
v = v.reshape((-1, 1))

# add the average value of all classes to v
class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
if class_labels is None:
class_labels = ["class" + str(i) for i in range(v.shape[1])]
elitap marked this conversation as resolved.
Show resolved Hide resolved
else:
class_labels = [str(i) for i in class_labels] # ensure to have a list of str

class_labels += ["mean"]
v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)

with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
Expand Down
Loading