Skip to content

Commit

Permalink
Fix streaming parquet with image feature in schema (#5921)
Browse files Browse the repository at this point in the history
* fix streaming parquet with image feature in schema

* minor
  • Loading branch information
lhoestq authored Jun 2, 2023
1 parent 074925b commit 7e52021
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
16 changes: 11 additions & 5 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,28 @@ def _split_generators(self, dl_manager):
files = [files]
# Use `dl_manager.iter_files` to skip hidden files in an extracted archive
files = [dl_manager.iter_files(file) for file in files]
# Infer features is they are stoed in the arrow schema
if self.info.features is None:
for file in itertools.chain.from_iterable(files):
with open(file, "rb") as f:
self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f))
break
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.config.features is not None:
if self.info.features is not None:
# more expensive cast to support nested features with keys in a different order
# allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, self.config.features.arrow_schema)
pa_table = table_cast(pa_table, self.info.features.arrow_schema)
return pa_table

def _generate_tables(self, files):
schema = self.config.features.arrow_schema if self.config.features is not None else None
if self.config.features is not None and self.config.columns is not None:
schema = self.info.features.arrow_schema if self.info.features is not None else None
if self.info.features is not None and self.config.columns is not None:
if sorted(field.name for field in schema) != sorted(self.config.columns):
raise ValueError(
f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.config.features}'"
f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
)
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
Expand Down
4 changes: 3 additions & 1 deletion tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,7 @@ def test_dataset_to_parquet_keeps_features(shared_datadir, tmp_path):
assert writer.write() > 0

reloaded_dataset = Dataset.from_parquet(str(tmp_path / "foo.parquet"))

assert dataset.features == reloaded_dataset.features

reloaded_iterable_dataset = ParquetDatasetReader(str(tmp_path / "foo.parquet"), streaming=True).read()
assert dataset.features == reloaded_iterable_dataset.features

0 comments on commit 7e52021

Please sign in to comment.