Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add class label option to write metric report to improve readability … #7249

Merged
7 changes: 6 additions & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def write_metrics_reports(
metrics: dict[str, torch.Tensor | np.ndarray] | None,
metric_details: dict[str, torch.Tensor | np.ndarray] | None,
summary_ops: str | Sequence[str] | None,
class_labels: Sequence[str] | None = None,
deli: str = ",",
output_type: str = "csv",
) -> None:
Expand Down Expand Up @@ -91,6 +92,8 @@ class mean median max 5percentile 95percentile notnans
class1 6.0000 6.0000 6.0000 6.0000 6.0000 1.0000
mean 6.2500 6.2500 7.0000 5.5750 6.9250 2.0000

class_labels: list of class names used to name the classes in the output report, if None,
"class0", ..., "classn" are used, default to None.
deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
output_type: expected output file type, supported types: ["csv"], default to "csv".
Expand Down Expand Up @@ -118,7 +121,9 @@ class mean median max 5percentile 95percentile notnans
v = v.reshape((-1, 1))

# add the average value of all classes to v
class_labels = ["class" + str(i) for i in range(v.shape[1])] + ["mean"]
if class_labels is None:
class_labels = ["class" + str(i) for i in range(v.shape[1])]
elitap marked this conversation as resolved.
Show resolved Hide resolved
class_labels += ["mean"]
v = np.concatenate([v, np.nanmean(v, axis=1, keepdims=True)], axis=1)

with open(os.path.join(save_dir, f"{k}_raw.csv"), "w") as f:
Expand Down
Loading