Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: allow JSON (de)serialization of ExtensionDtypes #44722

Merged
merged 7 commits into from
Dec 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/development/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ As an example of fully-formed metadata:
'numpy_type': 'int64',
'metadata': None}
],
'pandas_version': '0.20.0',
'pandas_version': '1.4.0',
'creator': {
'library': 'pyarrow',
'version': '0.13.0'
Expand Down
12 changes: 11 additions & 1 deletion doc/source/user_guide/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,7 @@ with optional parameters:
``index``; dict like {index -> {column -> value}}
``columns``; dict like {column -> {index -> value}}
``values``; just the values array
``table``; adhering to the JSON `Table Schema`_

* ``date_format`` : string, type of date conversion, 'epoch' for timestamp, 'iso' for ISO8601.
* ``double_precision`` : The number of decimal places to use when encoding floating point values, default 10.
Expand Down Expand Up @@ -2477,7 +2478,6 @@ A few notes on the generated table schema:
* For ``MultiIndex``, ``mi.names`` is used. If any level has no name,
then ``level_<i>`` is used.


``read_json`` also accepts ``orient='table'`` as an argument. This allows for
the preservation of metadata such as dtypes and index names in a
round-trippable manner.
Expand Down Expand Up @@ -2519,8 +2519,18 @@ indicate missing values and the subsequent read cannot distinguish the intent.

os.remove("test.json")

When using ``orient='table'`` along with a user-defined ``ExtensionArray``,
the generated schema will contain an additional ``extDtype`` key in the respective
``fields`` element. This extra key is not standard but does enable JSON roundtrips
for extension types (e.g. ``read_json(df.to_json(orient="table"), orient="table")``).

The ``extDtype`` key carries the name of the extension. If you have properly registered
the ``ExtensionDtype``, pandas will use that name to look up the dtype in the registry
and re-convert the serialized data into your custom dtype.

.. _Table Schema: https://specs.frictionlessdata.io/table-schema/


HTML
----

Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ Other enhancements
- :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
- :meth:`read_json` can now parse unsigned long long integers (:issue:`26068`)
- :meth:`DataFrame.take` now raises a ``TypeError`` when passed a scalar for the indexer (:issue:`42875`)
- :class:`ExtensionDtype` and :class:`ExtensionArray` are now (de)serialized when exporting a :class:`DataFrame` with :meth:`DataFrame.to_json` using ``orient='table'`` (:issue:`20612`, :issue:`44705`).
-


Expand Down
2 changes: 1 addition & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2568,7 +2568,7 @@ def to_json(
"primaryKey": [
"index"
],
"pandas_version": "0.20.0"
"pandas_version": "1.4.0"
}},
"data": [
{{
Expand Down
4 changes: 1 addition & 3 deletions pandas/io/json/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@
loads = json.loads
dumps = json.dumps

TABLE_SCHEMA_VERSION = "0.20.0"
jreback marked this conversation as resolved.
Show resolved Hide resolved


# interface to/from
def to_json(
Expand Down Expand Up @@ -565,7 +563,7 @@ def read_json(
{{"name":"col 1","type":"string"}},\
{{"name":"col 2","type":"string"}}],\
"primaryKey":["index"],\
"pandas_version":"0.20.0"}},\
"pandas_version":"1.4.0"}},\
"data":[\
{{"index":"row 1","col 1":"a","col 2":"b"}},\
{{"index":"row 2","col 1":"c","col 2":"d"}}]\
Expand Down
14 changes: 12 additions & 2 deletions pandas/io/json/_table_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
JSONSerializable,
)

from pandas.core.dtypes.base import _registry as registry
from pandas.core.dtypes.common import (
is_bool_dtype,
is_categorical_dtype,
is_datetime64_dtype,
is_datetime64tz_dtype,
is_extension_array_dtype,
is_integer_dtype,
is_numeric_dtype,
is_period_dtype,
Expand All @@ -40,6 +42,8 @@

loads = json.loads

TABLE_SCHEMA_VERSION = "1.4.0"


def as_json_table_type(x: DtypeObj) -> str:
"""
Expand Down Expand Up @@ -83,6 +87,8 @@ def as_json_table_type(x: DtypeObj) -> str:
return "duration"
elif is_categorical_dtype(x):
return "any"
elif is_extension_array_dtype(x):
return "any"
elif is_string_dtype(x):
return "string"
else:
Expand Down Expand Up @@ -130,6 +136,8 @@ def convert_pandas_type_to_json_field(arr):
field["freq"] = dtype.freq.freqstr
elif is_datetime64tz_dtype(dtype):
field["tz"] = dtype.tz.zone
elif is_extension_array_dtype(dtype):
field["extDtype"] = dtype.name
return field


Expand Down Expand Up @@ -195,6 +203,8 @@ def convert_json_field_to_pandas_type(field):
return CategoricalDtype(
categories=field["constraints"]["enum"], ordered=field["ordered"]
)
elif "extDtype" in field:
return registry.find(field["extDtype"])
else:
return "object"

Expand Down Expand Up @@ -253,7 +263,7 @@ def build_table_schema(
{'name': 'B', 'type': 'string'}, \
{'name': 'C', 'type': 'datetime'}], \
'primaryKey': ['idx'], \
'pandas_version': '0.20.0'}
'pandas_version': '1.4.0'}
"""
if index is True:
data = set_default_names(data)
Expand Down Expand Up @@ -287,7 +297,7 @@ def build_table_schema(
schema["primaryKey"] = primary_key

if version:
schema["pandas_version"] = "0.20.0"
schema["pandas_version"] = TABLE_SCHEMA_VERSION
return schema


Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/extension/date/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pandas.tests.extension.date.array import (
DateArray,
DateDtype,
)

__all__ = ["DateArray", "DateDtype"]
180 changes: 180 additions & 0 deletions pandas/tests/extension/date/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import datetime as dt
from typing import (
Any,
Optional,
Sequence,
Tuple,
Union,
cast,
)

import numpy as np

from pandas._typing import (
Dtype,
PositionalIndexer,
)

from pandas.core.dtypes.dtypes import register_extension_dtype

from pandas.api.extensions import (
ExtensionArray,
ExtensionDtype,
)
from pandas.api.types import pandas_dtype


@register_extension_dtype
class DateDtype(ExtensionDtype):
    """An ExtensionDtype whose scalar type is a plain ``datetime.date``."""

    @property
    def type(self):
        # Scalar type held by the paired DateArray.
        return dt.date

    @property
    def name(self):
        # Registered name; this is also what ``construct_from_string`` matches.
        return "DateDtype"

    @classmethod
    def construct_from_string(cls, string: str):
        """Build an instance from its registered name, else raise ``TypeError``."""
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )
        if string != cls.__name__:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
        return cls()

    @classmethod
    def construct_array_type(cls):
        # The ExtensionArray subclass paired with this dtype.
        return DateArray

    @property
    def na_value(self):
        # ``0001-01-01`` doubles as the missing-value sentinel.
        return dt.date.min

    def __repr__(self) -> str:
        return self.name

class DateArray(ExtensionArray):
    """Minimal ExtensionArray storing dates as parallel year/month/day arrays.

    Dates are decomposed into three numpy arrays (``_year`` as uint16,
    ``_month`` and ``_day`` as uint8) to keep storage compact.
    ``datetime.date.min`` (0001-01-01) is used as the NA sentinel, matching
    ``DateDtype.na_value``.
    """

    def __init__(
        self,
        dates: Union[
            dt.date,
            Sequence[dt.date],
            Tuple[np.ndarray, np.ndarray, np.ndarray],
            np.ndarray,
        ],
    ) -> None:
        """
        Parameters
        ----------
        dates :
            One of:
            * a single ``datetime.date`` (produces a length-1 array)
            * a list of ``datetime.date``
            * a triple of numpy arrays ``(year, month, day)`` of equal length
            * a numpy array of ``"yyyy-mm-dd"`` strings (dtype ``U10``)

        Raises
        ------
        ValueError
            If a tuple does not have exactly 3 members of equal length.
        TypeError
            For unsupported input types.
        """
        if isinstance(dates, dt.date):
            self._year = np.array([dates.year])
            self._month = np.array([dates.month])
            # BUG FIX: this previously stored ``dates.year``, yielding
            # invalid day values for scalar input.
            self._day = np.array([dates.day])
            return

        ldates = len(dates)
        if isinstance(dates, list):
            # pre-allocate the arrays since we know the size beforehand
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)
            # populate them
            for i, date in enumerate(dates):
                self._year[i] = date.year
                self._month[i] = date.month
                self._day[i] = date.day

        elif isinstance(dates, tuple):
            # only support triples of (year, month, day) arrays
            if ldates != 3:
                raise ValueError("only triples are valid")
            # check if all elements are numpy arrays
            if any(not isinstance(x, np.ndarray) for x in dates):
                raise TypeError("invalid type")
            ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates)
            if not ly == lm == ld:
                raise ValueError(
                    f"tuple members must have the same length: {(ly, lm, ld)}"
                )
            self._year = dates[0].astype(np.uint16)
            self._month = dates[1].astype(np.uint8)
            self._day = dates[2].astype(np.uint8)

        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
            # "yyyy-mm-dd" strings; 10 chars each
            self._year = np.zeros(ldates, dtype=np.uint16)  # 65535 (0, 9999)
            self._month = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 12)
            self._day = np.zeros(ldates, dtype=np.uint8)  # 255 (1, 31)

            for (i,), (y, m, d) in np.ndenumerate(np.char.split(dates, sep="-")):
                self._year[i] = int(y)
                self._month[i] = int(m)
                self._day[i] = int(d)

        else:
            raise TypeError(f"{type(dates)} is not supported")

    @property
    def dtype(self) -> ExtensionDtype:
        return DateDtype()

    def astype(self, dtype, copy=True):
        """Cast to ``dtype``; non-Date dtypes go through ``to_numpy``."""
        dtype = pandas_dtype(dtype)

        if isinstance(dtype, DateDtype):
            data = self.copy() if copy else self
        else:
            # NA values materialize as the ``dt.date.min`` sentinel
            data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)

        return data

    @property
    def nbytes(self) -> int:
        return self._year.nbytes + self._month.nbytes + self._day.nbytes

    def __len__(self) -> int:
        return len(self._year)  # all 3 arrays are enforced to have the same length

    def __getitem__(self, item: PositionalIndexer):
        if isinstance(item, int):
            return dt.date(self._year[item], self._month[item], self._day[item])
        else:
            raise NotImplementedError("only ints are supported as indexes")

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any):
        if not isinstance(key, int):
            raise NotImplementedError("only ints are supported as indexes")

        if not isinstance(value, dt.date):
            raise TypeError("you can only set datetime.date types")

        self._year[key] = value.year
        self._month[key] = value.month
        self._day[key] = value.day

    def __repr__(self) -> str:
        return f"DateArray{list(zip(self._year, self._month, self._day))}"

    def copy(self) -> "DateArray":
        # Deep-copy the three component arrays.
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))

    def isna(self) -> np.ndarray:
        """Boolean mask; an entry is NA iff all components equal ``date.min``."""
        return np.logical_and(
            np.logical_and(
                self._year == dt.date.min.year, self._month == dt.date.min.month
            ),
            self._day == dt.date.min.day,
        )

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
        """Construct from a sequence of scalars; string ndarrays are normalized
        to ``U10`` ("yyyy-mm-dd"); other inputs are passed through to __init__."""
        if isinstance(scalars, dt.date):
            pass
        elif isinstance(scalars, DateArray):
            pass
        elif isinstance(scalars, np.ndarray):
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
        return DateArray(scalars)
7 changes: 5 additions & 2 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):

def __init__(self, values, dtype=None, copy=False, context=None):
for i, val in enumerate(values):
if is_float(val) and np.isnan(val):
values[i] = DecimalDtype.na_value
if is_float(val):
jmg-duarte marked this conversation as resolved.
Show resolved Hide resolved
if np.isnan(val):
values[i] = DecimalDtype.na_value
else:
values[i] = DecimalDtype.type(val)
elif not isinstance(val, decimal.Decimal):
raise TypeError("All values must be of type " + str(decimal.Decimal))
values = np.asarray(values, dtype=object)
Expand Down
Loading