diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index d1a64031a2225..17c1ba02ad61f 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -16,6 +16,7 @@ import pytest import torch +import csv from lightning.fabric.loggers import CSVLogger from lightning.fabric.loggers.csv_logs import _ExperimentWriter @@ -117,3 +118,26 @@ def test_automatic_step_tracking(tmp_path): logger.log_metrics(metrics, step=None) logger.save.assert_called_once() assert logger.experiment.metrics[2]["step"] == 2 + + +def test_append_columns(tmp_path): + """Test that the CSV file gets rewritten with new headers if the columns change.""" + logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1) + + # initial metrics + metrics = {"a": 1, "b": 2} + logger.log_metrics(metrics) + + # new key appears + metrics = {"a": 1, "b": 2, "c": 3} + logger.log_metrics(metrics) + with open(logger.experiment.metrics_file_path, "r") as file: + header = file.readline().strip() + assert set(header.split(",")) == {"step", "a", "b", "c"} + + # key disappears + metrics = {"a": 1, "c": 3} + logger.log_metrics(metrics) + with open(logger.experiment.metrics_file_path, "r") as file: + header = file.readline().strip() + assert set(header.split(",")) == {"step", "a", "b", "c"} diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 36e70a4dbdf6a..8480826129f88 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -136,3 +136,26 @@ def test_flush_n_steps(tmpdir): logger.save.assert_not_called() logger.log_metrics(metrics, step=1) logger.save.assert_called_once() + + +def test_append_columns(tmp_path): + """Test that the CSV file gets rewritten with new headers if the columns change.""" + logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1) + + # initial metrics + metrics = {"a": 1, "b": 2} + logger.log_metrics(metrics) + + # new key appears + metrics = {"a": 1, "b": 2, "c": 3} + logger.log_metrics(metrics) + with open(logger.experiment.metrics_file_path, "r") as file: + header = file.readline().strip() + assert set(header.split(",")) == {"step", "a", "b", "c"} + + # key disappears + metrics = {"a": 1, "c": 3} + logger.log_metrics(metrics) + with open(logger.experiment.metrics_file_path, "r") as file: + header = file.readline().strip() + assert set(header.split(",")) == {"step", "a", "b", "c"}