Skip to content

Commit

Permalink
feat: raise Exception if user mixes Form styles
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 6, 2022
1 parent 198b329 commit c4db00a
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 38 deletions.
116 changes: 78 additions & 38 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,24 @@
from typing import Any, Mapping

import awkward as ak
from awkward import _errors

np = ak.nplikes.NumpyMetadata.instance()


def from_dict(input: dict) -> Form:
return _from_dict(input, is_legacy_record=None)


def _from_dict(input: dict, is_legacy_record: bool | None) -> Form:
"""
Args:
input: form dictionary
is_legacy_record: whether to expect record forms in the legacy style
Returns:
"""
if input is None:
return None

Expand All @@ -33,7 +46,7 @@ def from_dict(input: dict) -> Form:

elif input["class"] == "RegularArray":
return ak.forms.RegularForm(
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
size=input["size"],
has_identifier=has_identifier,
parameters=parameters,
Expand All @@ -44,7 +57,7 @@ def from_dict(input: dict) -> Form:
return ak.forms.ListForm(
starts=input["starts"],
stops=input["stops"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
Expand All @@ -58,28 +71,53 @@ def from_dict(input: dict) -> Form:
):
return ak.forms.ListOffsetForm(
offsets=input["offsets"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
)

elif input["class"] == "RecordArray":
# Keep track of the current style record before
# we read this one
last_is_legacy_record = is_legacy_record

# New serialisation
if "fields" in input:
contents = [from_dict(content) for content in input["contents"]]
is_legacy_record = False
contents = [
_from_dict(content, is_legacy_record) for content in input["contents"]
]
fields = input["fields"]
# Old style record
elif isinstance(input["contents"], dict):
contents = []
fields = []
for key, content in input["contents"].items():
contents.append(from_dict(content))
fields.append(key)
# Old style tuple
else:
contents = [from_dict(content) for content in input["contents"]]
fields = None
is_legacy_record = True
if isinstance(input["contents"], dict):
contents = []
fields = []
for key, content in input["contents"].items():
contents.append(from_dict(content))
fields.append(key)
# Old style tuple
else:
contents = [
_from_dict(content, is_legacy_record)
for content in input["contents"]
]
fields = None

# If we have read two records of different styles, we must warn the user
if (
last_is_legacy_record is not None
and is_legacy_record != last_is_legacy_record
):
raise _errors.wrap_error(
ValueError(
"encountered an old-style RecordArray form after a new-style RecordArray form. "
"Forms should not mix RecordArray styles"
)
)

return ak.forms.RecordForm(
contents=contents,
fields=fields,
Expand All @@ -96,7 +134,7 @@ def from_dict(input: dict) -> Form:
):
return ak.forms.IndexedForm(
index=input["index"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
Expand All @@ -109,7 +147,7 @@ def from_dict(input: dict) -> Form:
):
return ak.forms.IndexedOptionForm(
index=input["index"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
Expand All @@ -118,7 +156,7 @@ def from_dict(input: dict) -> Form:
elif input["class"] == "ByteMaskedArray":
return ak.forms.ByteMaskedForm(
mask=input["mask"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
valid_when=input["valid_when"],
has_identifier=has_identifier,
parameters=parameters,
Expand All @@ -128,7 +166,7 @@ def from_dict(input: dict) -> Form:
elif input["class"] == "BitMaskedArray":
return ak.forms.BitMaskedForm(
mask=input["mask"],
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
valid_when=input["valid_when"],
lsb_order=input["lsb_order"],
has_identifier=has_identifier,
Expand All @@ -138,7 +176,7 @@ def from_dict(input: dict) -> Form:

elif input["class"] == "UnmaskedArray":
return ak.forms.UnmaskedForm(
content=from_dict(input["content"]),
content=_from_dict(input["content"], is_legacy_record),
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
Expand All @@ -153,19 +191,21 @@ def from_dict(input: dict) -> Form:
return ak.forms.UnionForm(
tags=input["tags"],
index=input["index"],
contents=[from_dict(content) for content in input["contents"]],
contents=[
_from_dict(content, is_legacy_record) for content in input["contents"]
],
has_identifier=has_identifier,
parameters=parameters,
form_key=form_key,
)

elif input["class"] == "VirtualArray":
raise ak._errors.wrap_error(
raise _errors.wrap_error(
ValueError("Awkward 1.x VirtualArrays are not supported")
)

else:
raise ak._errors.wrap_error(
raise _errors.wrap_error(
ValueError(
"Input class: {} was not recognised".format(repr(input["class"]))
)
Expand Down Expand Up @@ -271,23 +311,23 @@ class Form:

def _init(self, has_identifier, parameters, form_key):
if not isinstance(has_identifier, bool):
raise ak._errors.wrap_error(
raise _errors.wrap_error(
TypeError(
"{} 'has_identifier' must be of type bool, not {}".format(
type(self).__name__, repr(has_identifier)
)
)
)
if parameters is not None and not isinstance(parameters, dict):
raise ak._errors.wrap_error(
raise _errors.wrap_error(
TypeError(
"{} 'parameters' must be of type dict or None, not {}".format(
type(self).__name__, repr(parameters)
)
)
)
if form_key is not None and not ak._util.isstr(form_key):
raise ak._errors.wrap_error(
raise _errors.wrap_error(
TypeError(
"{} 'form_key' must be of type string or None, not {}".format(
type(self).__name__, repr(form_key)
Expand All @@ -312,7 +352,7 @@ def parameters(self):
@property
def is_identity_like(self):
"""Return True if the content or its non-list descendents are an identity"""
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

def parameter(self, key):
if self._parameters is None:
Expand All @@ -321,31 +361,31 @@ def parameter(self, key):
return self._parameters.get(key)

def purelist_parameter(self, key):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def purelist_isregular(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def purelist_depth(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def minmax_depth(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def branch_depth(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def fields(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def is_tuple(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

@property
def form_key(self):
Expand Down Expand Up @@ -402,7 +442,7 @@ def select_columns(self, specifier, expand_braces=True):

for item in specifier:
if not ak._util.isstr(item):
raise ak._errors.wrap_error(
raise _errors.wrap_error(
TypeError("a column-selection specifier must be a list of strings")
)

Expand All @@ -423,16 +463,16 @@ def column_types(self):
return self._column_types()

def _columns(self, path, output, list_indicator):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

def _select_columns(self, index, specifier, matches, output):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

def _column_types(self):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)

def _to_dict_part(self, verbose, toplevel):
raise ak._errors._errors(NotImplementedError)
raise _errors._errors(NotImplementedError)

def _type(self, typestrs):
raise ak._errors.wrap_error(NotImplementedError)
raise _errors.wrap_error(NotImplementedError)
40 changes: 40 additions & 0 deletions tests/test_1766-record-form-fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,43 @@ def test_old_style_tuple():
assert array.is_tuple
assert array.fields == ["0", "1"]
assert array.to_list() == [(1, 2)]


def test_new_style_old_style_record():
form = {
"class": "RecordArray",
"fields": ["x"],
"contents": [
{
"class": "RecordArray",
"contents": {
"y": {
"class": "ListOffsetArray",
"offsets": "i64",
"content": {
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"has_identifier": False,
"parameters": {},
"form_key": "node3",
},
"has_identifier": False,
"parameters": {},
"form_key": "node2",
}
},
"has_identifier": False,
"parameters": {},
"form_key": "node1",
}
],
"has_identifier": False,
"parameters": {},
"form_key": "node0",
}

with pytest.raises(ValueError, match=".*Forms should not mix RecordArray styles.*"):
ak.from_buffers(
form, 1, {"node2-offsets": np.array([0, 1]), "node3-data": np.array([0])}
)

0 comments on commit c4db00a

Please sign in to comment.