Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions test/nodes/test_csv_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os
import tempfile

from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase

from torchdata.nodes.csv_reader import CSVReader

from .utils import run_test_save_load_state


class TestCSVReader(TestCase):
    """Unit tests for CSVReader covering reading, state save/restore,
    delimiters, empty files, and resource cleanup."""

    def setUp(self):
        self.test_data = [
            ["Alice", "30", "New York"],
            ["Bob", "25", "London"],
            ["Charlie", "35", "Paris"],
            ["David", "40", "Rome"],
            ["Eve", "45", "Tokyo"],
            ["Frank", "50", "Beijing"],
            ["Grace", "55", "Shanghai"],
            ["Harry", "60", "Seoul"],
            ["Iris", "65", "Buenos Aires"],
            ["Jack", "70", "Sao Paulo"],
            ["Katy", "75", "Mexico City"],
            ["Lily", "80", "Bogota"],
        ]
        # Paths of temp files created by this test, removed in tearDown.
        self._tmp_paths = []

    def _create_temp_csv(self, delimiter=",", header=True):
        """Write self.test_data to a fresh temp CSV file and return its path.

        NOTE: when header=True the header row is inserted into
        self.test_data itself; the assertions below rely on this
        (len(self.test_data) - 1 == number of data rows).
        """
        if header:
            self.test_data.insert(0, ["name", "age", "city"])
        fd, path = tempfile.mkstemp(suffix=".csv")
        with os.fdopen(fd, "w", newline="") as f:
            writer = csv.writer(f, delimiter=delimiter)
            writer.writerows(self.test_data)
        self._tmp_paths.append(path)
        return path

    def test_basic_read_list(self):
        """Rows come back as lists when return_dict is left False."""
        path = self._create_temp_csv(header=False)
        node = CSVReader(path, has_header=False)
        results = list(node)
        self.assertEqual(len(results), len(self.test_data))
        self.assertEqual(results[0], ["Alice", "30", "New York"])
        self.assertEqual(results[-1], ["Lily", "80", "Bogota"])
        node.close()

    def test_basic_read_dict(self):
        """Rows come back as header-keyed dicts when return_dict=True."""
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)
        results = list(node)

        # Header row is consumed, not yielded.
        self.assertEqual(len(results), len(self.test_data) - 1)
        self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
        self.assertEqual(results[1]["city"], "London")
        self.assertEqual(results[-1]["city"], "Bogota")
        node.close()

    def test_different_delimiters(self):
        """A non-default delimiter is honored for both writing and reading."""
        path = self._create_temp_csv(delimiter="|")
        node = CSVReader(path, has_header=True, delimiter="|", return_dict=True)
        results = list(node)

        self.assertEqual(len(results), len(self.test_data) - 1)
        self.assertEqual(results[2]["city"], "Paris")
        self.assertEqual(results[-1]["city"], "Bogota")
        node.close()

    def test_state_management(self):
        """Resetting from a saved state resumes at the exact next row."""
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)
        # Consume 11 of the 12 data rows.
        for _ in range(11):
            next(node)

        state = node.state_dict()

        node.reset(state)
        item = next(node)

        # Only the last row remained after the checkpoint.
        with self.assertRaises(StopIteration):
            next(node)

        self.assertEqual(item["name"], "Lily")
        self.assertEqual(state[CSVReader.NUM_LINES_YIELDED], 11)
        node.close()

    @parameterized.expand([3, 5, 7])
    def test_save_load_state(self, midpoint: int):
        """Generic save/load harness at several checkpoint positions."""
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)
        run_test_save_load_state(self, node, midpoint)
        node.close()

    def test_load_wrong_state(self):
        """Restoring a state whose header disagrees with has_header raises."""
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)

        state = node.state_dict()
        state[CSVReader.HEADER_KEY] = None
        with self.assertRaisesRegex(
            ValueError, "Check if has_header=True matches the state header=None"
        ):
            node.reset(state)

        node.close()

        node = CSVReader(path, has_header=False)
        state = node.state_dict()
        state[CSVReader.HEADER_KEY] = ["name", "age"]
        with self.assertRaisesRegex(
            ValueError,
            r"Check if has_header=False matches the state header=\['name', 'age'\]",
        ):
            node.reset(state)

        node.close()

    def test_empty_file(self):
        """An empty CSV file yields nothing and raises StopIteration."""
        path = self._create_temp_csv()
        # Overwrite with empty file
        with open(path, "w") as _:
            pass

        node = CSVReader(path, has_header=False)
        with self.assertRaises(StopIteration):
            next(node)
        node.close()

    def test_header_validation(self):
        """return_dict=True without a header is rejected at construction."""
        with self.assertRaisesRegex(
            ValueError, "return_dict=True requires has_header=True"
        ):
            CSVReader("dummy.csv", has_header=False, return_dict=True)

    def test_multi_epoch(self):
        """reset() with no state restarts from the first data row."""
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)

        # First epoch
        epoch1 = list(node)
        self.assertEqual(len(epoch1), len(self.test_data) - 1)

        # Second epoch
        node.reset()
        epoch2 = list(node)
        self.assertEqual(epoch1, epoch2)
        node.close()

    def test_partial_read_resume(self):
        """States captured mid-stream resume at distinct, correct rows."""
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)

        # Read partial and get state.  With the header inserted at index 0,
        # data rows are self.test_data[1:].
        _ = next(node)  # Line 0
        state1 = node.state_dict()

        _ = next(node)  # Line 1
        state2 = node.state_dict()

        # Resume from first state
        node.reset(state1)
        self.assertEqual(next(node), self.test_data[2])

        # Resume from second state
        node.reset(state2)
        self.assertEqual(next(node), self.test_data[3])
        node.close()

    def test_file_closure(self):
        """Exhausting the reader closes its underlying file handle."""
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True)

        # Read all items
        list(node)

        # Verify file is closed
        self.assertTrue(node._file.closed)
        node.close()

    def test_state_with_header(self):
        """The saved header survives a save/restore round trip."""
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)

        # Read one item
        _ = next(node)
        state = node.state_dict()

        # Verify header preservation
        node.reset(state)
        item = next(node)
        self.assertEqual(item["city"], "London")
        node.close()

    def tearDown(self):
        # Remove only the files this test created — deleting every
        # tmp*.csv in the system temp dir could clobber files owned by
        # other processes.
        for path in self._tmp_paths:
            if os.path.exists(path):
                os.remove(path)
128 changes: 128 additions & 0 deletions torchdata/nodes/csv_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
from itertools import islice
from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union

from torchdata.nodes.base_node import BaseNode


class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
    """Node for reading CSV files with state management and header support.

    Yields one row per call to ``next()``: a list of strings, or a dict
    keyed by the header row when ``return_dict=True``.

    Args:
        file_path: Path to CSV file
        has_header: Whether first row contains column headers
        delimiter: CSV field delimiter
        return_dict: Return rows as dictionaries (requires has_header=True)
        encoding: Text encoding used to open the file
    """

    # Keys used in the state dict produced by get_state().
    NUM_LINES_YIELDED = "num_lines_yielded"
    HEADER_KEY = "header"

    def __init__(
        self,
        file_path: str,
        has_header: bool = False,
        delimiter: str = ",",
        return_dict: bool = False,
        encoding: str = "utf-8",
    ):
        super().__init__()
        self.file_path = file_path
        self.has_header = has_header
        self.delimiter = delimiter
        self.return_dict = return_dict
        if return_dict and not has_header:
            raise ValueError("return_dict=True requires has_header=True")
        self.encoding = encoding
        self._file: Optional[TextIO] = None
        self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
        self._header: Optional[Sequence[str]] = None
        self._num_lines_yielded: int = 0
        self.reset()  # Open the file and create the initial reader

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Restart from the beginning, or from ``initial_state`` if given."""
        super().reset()
        self.close()

        # Reopen the file and reset counters
        self._file = open(self.file_path, encoding=self.encoding)
        self._num_lines_yielded = 0
        if initial_state is not None:
            self._handle_initial_state(initial_state)
        else:
            self._initialize_reader()

    def _handle_initial_state(self, state: Dict[str, Any]):
        """Restore reader state from checkpoint.

        Raises:
            ValueError: if the saved header is inconsistent with this
                reader's ``has_header`` setting.
        """
        # Validate header compatibility.  get_state() always stores the
        # HEADER_KEY entry (value None when has_header=False), so the check
        # must look at the saved *value*, not mere key presence — otherwise
        # a has_header=False reader could never restore its own state.
        saved_header = state.get(self.HEADER_KEY)
        if (not self.has_header and saved_header is not None) or (
            self.has_header and saved_header is None
        ):
            raise ValueError(
                f"Check if has_header={self.has_header} matches the state header={saved_header}"
            )

        self._header = saved_header
        target_line_num = state[self.NUM_LINES_YIELDED]
        assert self._file is not None
        # Create appropriate reader.  fieldnames is supplied explicitly so
        # DictReader does not consume the first line as a header here.
        if self.return_dict:
            self._reader = csv.DictReader(
                self._file, delimiter=self.delimiter, fieldnames=self._header
            )
        else:
            self._reader = csv.reader(self._file, delimiter=self.delimiter)

        assert isinstance(self._reader, Iterator)
        # Skip the physical header line if the file has one.
        if self.has_header:
            try:
                next(self._reader)  # Skip header line
            except StopIteration:
                pass  # Empty file
        # Fast-forward to target line using efficient slicing.
        # NOTE(review): if the file shrank since the checkpoint, consumed
        # may be < target_line_num; we record the actual count consumed.
        consumed = sum(1 for _ in islice(self._reader, target_line_num))
        self._num_lines_yielded = consumed

    def _initialize_reader(self):
        """Create fresh reader without state."""
        assert self._file is not None
        if self.return_dict:
            # DictReader consumes the first line itself and exposes it as
            # fieldnames (has_header is guaranteed True by __init__).
            self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
            self._header = self._reader.fieldnames
        else:
            self._reader = csv.reader(self._file, delimiter=self.delimiter)
            if self.has_header:
                try:
                    self._header = next(self._reader)
                except StopIteration:
                    self._header = None  # Handle empty file

    def next(self) -> Union[List[str], Dict[str, str]]:
        """Return the next row, closing the file at end of input."""
        try:
            assert isinstance(self._reader, Iterator)
            row = next(self._reader)
            self._num_lines_yielded += 1
            return row
        except StopIteration:
            # Release the file handle eagerly once the stream is exhausted.
            self.close()
            raise

    def get_state(self) -> Dict[str, Any]:
        """Return a checkpoint: lines yielded so far plus the header (or None)."""
        return {
            self.NUM_LINES_YIELDED: self._num_lines_yielded,
            self.HEADER_KEY: self._header,
        }

    def close(self):
        """Close the underlying file if open.

        _file is intentionally kept (not set to None) so callers can still
        observe the closed handle after exhaustion.
        """
        if self._file is not None and not self._file.closed:
            self._file.close()
Loading