From cb834d9c63ab8cb14725ae8e4fc2da8672892a6d Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 28 Feb 2024 07:39:12 +0100 Subject: [PATCH] Make JSON builder support an array of strings (#6696) * Test JSON builder with list of strings * Make JSON builder support array of strings --- src/datasets/packaged_modules/json/json.py | 16 +++++++++++---- tests/packaged_modules/test_json.py | 24 +++++++++++++++++++++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 0008d45564c..4c017a642f6 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -145,12 +145,20 @@ def _generate_tables(self, files): except json.JSONDecodeError: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise e - # If possible, parse the file as a list of json objects and exit the loop + # If possible, parse the file as a list of json objects/strings and exit the loop if isinstance(dataset, list): # list is the only sequence type supported in JSON try: - keys = set().union(*[row.keys() for row in dataset]) - mapping = {col: [row.get(col) for row in dataset] for col in keys} - pa_table = pa.Table.from_pydict(mapping) + if dataset and isinstance(dataset[0], str): + pa_table_names = ( + list(self.config.features) + if self.config.features is not None + else ["text"] + ) + pa_table = pa.Table.from_arrays([pa.array(dataset)], names=pa_table_names) + else: + keys = set().union(*[row.keys() for row in dataset]) + mapping = {col: [row.get(col) for row in dataset] for col in keys} + pa_table = pa.Table.from_pydict(mapping) except (pa.ArrowInvalid, AttributeError) as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise ValueError(f"Not able to read records in the JSON file at {file}.") from None diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 43bd90ac98b..9375fd443b4 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -54,6 +54,23 @@ def json_file_with_list_of_dicts(tmp_path): return str(filename) +@pytest.fixture +def json_file_with_list_of_strings(tmp_path): + filename = tmp_path / "file_with_list_of_strings.json" + data = textwrap.dedent( + """\ + [ + "First text.", + "Second text.", + "Third text." + ] + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + @pytest.fixture def json_file_with_list_of_dicts_field(tmp_path): filename = tmp_path / "file_with_list_of_dicts_field.json" @@ -82,13 +99,18 @@ def json_file_with_list_of_dicts_field(tmp_path): ("jsonl_file_utf16_encoded", {"encoding": "utf-16"}), ("json_file_with_list_of_dicts", {}), ("json_file_with_list_of_dicts_field", {"field": "field3"}), + ("json_file_with_list_of_strings", {}), ], ) def test_json_generate_tables(file_fixture, config_kwargs, request): json = Json(**config_kwargs) generator = json._generate_tables([[request.getfixturevalue(file_fixture)]]) pa_table = pa.concat_tables([table for _, table in generator]) - assert pa_table.to_pydict() == {"col_1": [-1, 1, 10], "col_2": [None, 2, 20]} + if file_fixture == "json_file_with_list_of_strings": + expected = {"text": ["First text.", "Second text.", "Third text."]} + else: + expected = {"col_1": [-1, 1, 10], "col_2": [None, 2, 20]} + assert pa_table.to_pydict() == expected @pytest.mark.parametrize(