Skip to content

Commit

Permalink
Make JSON builder support an array of strings (#6696)
Browse files Browse the repository at this point in the history
* Test JSON builder with list of strings

* Make JSON builder support array of strings
  • Loading branch information
albertvillanova authored Feb 28, 2024
1 parent 9c424fa commit cb834d9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
16 changes: 12 additions & 4 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,20 @@ def _generate_tables(self, files):
except json.JSONDecodeError:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise e
# If possible, parse the file as a list of json objects and exit the loop
# If possible, parse the file as a list of json objects/strings and exit the loop
if isinstance(dataset, list): # list is the only sequence type supported in JSON
try:
keys = set().union(*[row.keys() for row in dataset])
mapping = {col: [row.get(col) for row in dataset] for col in keys}
pa_table = pa.Table.from_pydict(mapping)
if dataset and isinstance(dataset[0], str):
pa_table_names = (
list(self.config.features)
if self.config.features is not None
else ["text"]
)
pa_table = pa.Table.from_arrays([pa.array(dataset)], names=pa_table_names)
else:
keys = set().union(*[row.keys() for row in dataset])
mapping = {col: [row.get(col) for row in dataset] for col in keys}
pa_table = pa.Table.from_pydict(mapping)
except (pa.ArrowInvalid, AttributeError) as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
Expand Down
24 changes: 23 additions & 1 deletion tests/packaged_modules/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ def json_file_with_list_of_dicts(tmp_path):
return str(filename)


@pytest.fixture
def json_file_with_list_of_strings(tmp_path):
filename = tmp_path / "file_with_list_of_strings.json"
data = textwrap.dedent(
"""\
[
"First text.",
"Second text.",
"Third text."
]
"""
)
with open(filename, "w") as f:
f.write(data)
return str(filename)


@pytest.fixture
def json_file_with_list_of_dicts_field(tmp_path):
filename = tmp_path / "file_with_list_of_dicts_field.json"
Expand Down Expand Up @@ -82,13 +99,18 @@ def json_file_with_list_of_dicts_field(tmp_path):
("jsonl_file_utf16_encoded", {"encoding": "utf-16"}),
("json_file_with_list_of_dicts", {}),
("json_file_with_list_of_dicts_field", {"field": "field3"}),
("json_file_with_list_of_strings", {}),
],
)
def test_json_generate_tables(file_fixture, config_kwargs, request):
json = Json(**config_kwargs)
generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.to_pydict() == {"col_1": [-1, 1, 10], "col_2": [None, 2, 20]}
if file_fixture == "json_file_with_list_of_strings":
expected = {"text": ["First text.", "Second text.", "Third text."]}
else:
expected = {"col_1": [-1, 1, 10], "col_2": [None, 2, 20]}
assert pa_table.to_pydict() == expected


@pytest.mark.parametrize(
Expand Down

0 comments on commit cb834d9

Please sign in to comment.