diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index aecf32c5076be..2dce1e613d6a6 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -16,7 +16,7 @@ # under the License. from collections import OrderedDict -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from functools import partial import datetime import sys @@ -1325,17 +1325,36 @@ def test_types_come_back_with_specific_type(): assert type(type_back) is type(arrow_type) -def test_schema_import_c_schema_interface(): - class Wrapper: - def __init__(self, schema): - self.schema = schema +class SchemaWrapper: + def __init__(self, schema): + self.schema = schema + + def __arrow_c_schema__(self): + return self.schema.__arrow_c_schema__() + + +class SchemaMapping(Mapping): + def __init__(self, schema): + self.schema = schema + + def __arrow_c_schema__(self): + return self.schema.__arrow_c_schema__() + + def __getitem__(self, key): + return self.schema[key] + + def __iter__(self): + return iter(self.schema) + + def __len__(self): + return len(self.schema) - def __arrow_c_schema__(self): - return self.schema.__arrow_c_schema__() +@pytest.mark.parametrize("wrapper_class", [SchemaWrapper, SchemaMapping]) +def test_schema_import_c_schema_interface(wrapper_class): schema = pa.schema([pa.field("field_name", pa.int32())], metadata={"a": "b"}) assert schema.metadata == {b"a": b"b"} - wrapped_schema = Wrapper(schema) + wrapped_schema = wrapper_class(schema) assert pa.schema(wrapped_schema) == schema assert pa.schema(wrapped_schema).metadata == {b"a": b"b"} diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 4343d7ea300b0..ab20d74dd997e 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -5347,14 +5347,15 @@ def schema(fields, metadata=None): Field py_field vector[shared_ptr[CField]] c_fields - if isinstance(fields, Mapping): - fields = fields.items() - elif hasattr(fields, "__arrow_c_schema__"): + if hasattr(fields, "__arrow_c_schema__"): result = Schema._import_from_c_capsule(fields.__arrow_c_schema__()) if metadata is not None: result = result.with_metadata(metadata) return result + if isinstance(fields, Mapping): + fields = fields.items() + for item in fields: if isinstance(item, tuple): py_field = field(*item)