Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Sep 15, 2023
1 parent 158442a commit 2807139
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/tests_fabric/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
import torch
import csv

from lightning.fabric.loggers import CSVLogger
from lightning.fabric.loggers.csv_logs import _ExperimentWriter
Expand Down Expand Up @@ -117,3 +118,26 @@ def test_automatic_step_tracking(tmp_path):
logger.log_metrics(metrics, step=None)
logger.save.assert_called_once()
assert logger.experiment.metrics[2]["step"] == 2


def test_append_columns(tmp_path):
"""Test that the CSV file gets rewritten with new headers if the columns change."""
logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1)

# initial metrics
metrics = {"a": 1, "b": 2}
logger.log_metrics(metrics)

# new key appears
metrics = {"a": 1, "b": 2, "c": 3}
logger.log_metrics(metrics)
with open(logger.experiment.metrics_file_path, "r") as file:
header = file.readline().strip()
assert set(header.split(",")) == {"step", "a", "b", "c"}

# key disappears
metrics = {"a": 1, "c": 3}
logger.log_metrics(metrics)
with open(logger.experiment.metrics_file_path, "r") as file:
header = file.readline().strip()
assert set(header.split(",")) == {"step", "a", "b", "c"}
23 changes: 23 additions & 0 deletions tests/tests_pytorch/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,26 @@ def test_flush_n_steps(tmpdir):
logger.save.assert_not_called()
logger.log_metrics(metrics, step=1)
logger.save.assert_called_once()


def test_append_columns(tmp_path):
"""Test that the CSV file gets rewritten with new headers if the columns change."""
logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1)

# initial metrics
metrics = {"a": 1, "b": 2}
logger.log_metrics(metrics)

# new key appears
metrics = {"a": 1, "b": 2, "c": 3}
logger.log_metrics(metrics)
with open(logger.experiment.metrics_file_path, "r") as file:
header = file.readline().strip()
assert set(header.split(",")) == {"step", "a", "b", "c"}

# key disappears
metrics = {"a": 1, "c": 3}
logger.log_metrics(metrics)
with open(logger.experiment.metrics_file_path, "r") as file:
header = file.readline().strip()
assert set(header.split(",")) == {"step", "a", "b", "c"}

0 comments on commit 2807139

Please sign in to comment.