From 94d96841249d1a6a9ced79131783bb12c730fd73 Mon Sep 17 00:00:00 2001 From: Andres Rios Tascon Date: Tue, 5 Nov 2024 14:05:27 -0500 Subject: [PATCH] chore: miscellaneous RNTuple improvements (#1250) * Fixed __len__ method * Added a few more useful methods * Use the right number in arrays method * Updated to match spec and did some cleanup * Fixed order of extra type information * Extract column summary flags * style: pre-commit fixes * Fixed conflict resolution * Fixed test * Switched to using enums * Fixed RNTuple anchor * Updated locator types * Removed UserMetadata envelope * Started implementing new real32 types * Updated sharded cluster to match spec * Removed user metadata from footer * Fixed ClusterSummaryReader * Fix cascadentuple * Introduced RNTupleField class * Added test for #1285 * Fixed test * Fix test (attempt 2) * Finalized first version of RNTupleField * Added tests for RNTupleField * Implemented iterate method --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/uproot/const.py | 58 +- src/uproot/models/RNTuple.py | 560 ++++++++++++++---- src/uproot/writing/_cascadentuple.py | 10 +- tests-wasm/test_1272_basic_functionality.py | 2 +- tests/test_0013_rntuple_anchor.py | 2 +- tests/test_1191_rntuple_fixes.py | 2 +- tests/test_1250_rntuple_improvements.py | 85 +++ ...1285_rntuple_multicluster_concatenation.py | 20 + 8 files changed, 620 insertions(+), 119 deletions(-) create mode 100644 tests/test_1250_rntuple_improvements.py create mode 100644 tests/test_1285_rntuple_multicluster_concatenation.py diff --git a/src/uproot/const.py b/src/uproot/const.py index 8613076e0..4f6248822 100644 --- a/src/uproot/const.py +++ b/src/uproot/const.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -import struct +from enum import IntEnum import numpy @@ -118,8 +118,6 @@ kStreamedMemberWise = numpy.uint16(1 << 14) ############ RNTuple https://github.com/root-project/root/blob/master/tree/ntuple/v7/doc/specifications.md -_rntuple_frame_format = struct.Struct("HHHHQQQQQQQ") - -# https://github.com/root-project/root/blob/aa513463b0b512517370cb91cca025e53a8b13a2/tree/ntuple/v7/doc/specifications.md#envelopes +# https://github.com/root-project/root/blob/8635b1bc0da59623777c9fda3661a19363964915/tree/ntuple/v7/doc/specifications.md#feature-flags _rntuple_feature_flag_format = struct.Struct("> 1 ^ -(n & 1) @@ -47,6 +61,127 @@ def _envelop_header(chunk, cursor, context): return {"env_type_id": env_type_id, "env_length": env_length} +def _arrays( + in_ntuple, + filter_name="*", + filter_typename=None, + entry_start=0, + entry_stop=None, + decompression_executor=None, + array_cache=None, +): + ak = uproot.extras.awkward() + + entry_stop = entry_stop or in_ntuple.ntuple.num_entries + + clusters = in_ntuple.ntuple.cluster_summaries + cluster_starts = numpy.array([c.num_first_entry for c in clusters]) + + start_cluster_idx = ( + numpy.searchsorted(cluster_starts, entry_start, side="right") - 1 + ) + stop_cluster_idx = numpy.searchsorted(cluster_starts, entry_stop, side="right") + cluster_num_entries = numpy.sum( + [c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]] + ) + + form = in_ntuple.to_akform().select_columns( + filter_name, prune_unions_and_records=False + ) + # only read columns mentioned in the awkward form + target_cols = [] + container_dict = {} + _recursive_find(form, target_cols) + for key in target_cols: + if "column" in key and "union" not in key: + key_nr = int(key.split("-")[1]) + dtype_byte = in_ntuple.ntuple.column_records[key_nr].type + + content = in_ntuple.ntuple.read_col_pages( + key_nr, + range(start_cluster_idx, stop_cluster_idx), + dtype_byte=dtype_byte, + pad_missing_element=True, + ) + if "cardinality" in key: + content = numpy.diff(content) + if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]: + kindex, tags = _split_switch_bits(content) + # Find invalid variants and adjust buffers accordingly + invalid = numpy.flatnonzero(tags == -1) + if len(invalid) > 0: + kindex = numpy.delete(kindex, invalid) + tags = numpy.delete(tags, invalid) + invalid -= numpy.arange(len(invalid)) + optional_index = numpy.insert( + numpy.arange(len(kindex), dtype=numpy.int64), invalid, -1 + ) + else: + optional_index = numpy.arange(len(kindex), dtype=numpy.int64) + container_dict[f"{key}-index"] = optional_index + container_dict[f"{key}-union-index"] = kindex + container_dict[f"{key}-union-tags"] = tags + else: + # don't distinguish data and offsets + container_dict[f"{key}-data"] = content + container_dict[f"{key}-offsets"] = content + cluster_offset = cluster_starts[start_cluster_idx] + entry_start -= cluster_offset + entry_stop -= cluster_offset + return ak.from_buffers( + form, cluster_num_entries, container_dict, allow_noncanonical_form=True + )[entry_start:entry_stop] + + +def _num_entries_for(in_ntuple, target_num_bytes, filter_name): + # TODO: part of this is also done in _arrays, so we should refactor this + # TODO: there might be a better way to estimate the number of entries + entry_stop = in_ntuple.ntuple.num_entries + + clusters = in_ntuple.ntuple.cluster_summaries + cluster_starts = numpy.array([c.num_first_entry for c in clusters]) + + start_cluster_idx = numpy.searchsorted(cluster_starts, 0, side="right") - 1 + stop_cluster_idx = numpy.searchsorted(cluster_starts, entry_stop, side="right") + + form = in_ntuple.to_akform().select_columns( + filter_name, prune_unions_and_records=False + ) + target_cols = [] + _recursive_find(form, target_cols) + + total_bytes = 0 + for key in target_cols: + if "column" in key and "union" not in key: + key_nr = int(key.split("-")[1]) + for cluster in range(start_cluster_idx, stop_cluster_idx): + pages = in_ntuple.ntuple.page_list_envelopes.pagelinklist[cluster][ + key_nr + ] + total_bytes += sum(page.locator.num_bytes for page in pages) + + total_entries = entry_stop + if total_bytes == 0: + num_entries = 0 + else: + num_entries = int(round(target_num_bytes * total_entries / total_bytes)) + if num_entries <= 0: + return 1 + else: + return num_entries + + +def _regularize_step_size(in_ntuple, step_size, filter_name): + if uproot._util.isint(step_size): + return step_size + target_num_bytes = uproot._util.memory_size( + step_size, + "number of entries or memory size string with units " + f"(such as '100 MB') required, not {step_size!r}", + ) + return _num_entries_for(in_ntuple, target_num_bytes, filter_name) + + class Model_ROOT_3a3a_Experimental_3a3a_RNTuple(uproot.model.Model): """ A versionless :doc:`uproot.model.Model` for ``ROOT::Experimental::RNTuple``. @@ -76,6 +211,24 @@ def keys( else: return self._keys + @property + def _key_indices(self): + indices = [] + field_records = self.field_records + for i, fr in enumerate(field_records): + if fr.parent_field_id == i and fr.type_name != "": + indices.append(i) + return indices + + @property + def _key_to_index(self): + d = {} + field_records = self.field_records + for i, fr in enumerate(field_records): + if fr.parent_field_id == i and fr.type_name != "": + d[fr.field_name] = i + return d + def read_members(self, chunk, cursor, context, file): if uproot._awkwardforth.get_forth_obj(context) is not None: raise uproot.interpretation.objects.CannotBeForth() @@ -96,9 +249,11 @@ def read_members(self, chunk, cursor, context, file): self._members["fSeekFooter"], self._members["fNBytesFooter"], self._members["fLenFooter"], - self._members["fChecksum"], + self._members["fMaxKeySize"], ) = cursor.fields(chunk, _rntuple_anchor_format, context) + # TODO: There is a checksum afterwards that we can use to verify the integrity of the members. + self._header_chunk_ready = False self._footer_chunk_ready = False self._header, self._footer = None, None @@ -110,9 +265,13 @@ def read_members(self, chunk, cursor, context, file): self._alias_columns_dict_ = None self._related_ids_ = None self._column_records_dict_ = None + self._num_entries = None + self._length = None self._page_list_envelopes = [] + self.ntuple = self + def _prepare_header_chunk(self): context = {} seek, nbytes = self._members["fSeekHeader"], self._members["fNBytesHeader"] @@ -249,14 +408,62 @@ def footer(self): def cluster_summaries(self): return self.page_list_envelopes.cluster_summaries - # FIXME @property - def _length(self): - return sum(x.num_entries for x in self.cluster_summaries) + def num_entries(self): + if self._num_entries is None: + self._num_entries = sum(x.num_entries for x in self.cluster_summaries) + return self._num_entries def __len__(self): + if self._length is None: + self._length = len(self.keys()) return self._length + def __repr__(self): + if len(self) == 0: + return f"" + else: + return ( + f"" + ) + + def __getitem__(self, where): + # original_where = where + + if uproot._util.isint(where): + index = self._key_indices[where] + elif isinstance(where, str): + where = uproot._util.ensure_str(where) + index = self._key_to_index[where] + else: + raise TypeError(f"where must be an integer or a string, not {where!r}") + + # TODO: Implement path support + + return RNTupleField(index, self) + + @property + def name(self): + """ + Name of the ``RNTuple``. + """ + return self.parent.fName + + @property + def object_path(self): + """ + Object path of the ``RNTuple``. + """ + return self.parent.object_path + + @property + def cache_key(self): + """ + String that uniquely specifies this ``RNTuple`` in its path, to use as + part of object and array cache keys. + """ + return f"{self.parent.cache_key}{self.name};{self.parent.fCycle}" + def read_locator(self, loc, uncomp_size, context): cursor = uproot.source.cursor.Cursor(loc.offset) chunk = self.file.source.chunk(loc.offset, loc.offset + loc.num_bytes) @@ -291,9 +498,9 @@ def base_col_form(self, cr, col_id, parameters=None, cardinality=False): form_key = f"column-{col_id}" + ("-cardinality" if cardinality else "") dtype_byte = cr.type - if dtype_byte == uproot.const.rntuple_role_union: + if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]: return form_key - elif dtype_byte > uproot.const.rntuple_role_struct: + elif dtype_byte > uproot.const.rntuple_col_type_to_num_dict["switch"]: dt_str = uproot.const.rntuple_col_num_to_dtype_dict[dtype_byte] if dt_str == "bit": dt_str = "bool" @@ -348,7 +555,7 @@ def field_form(self, this_id, seen): seen.add(this_id) structural_role = this_record.struct_role if ( - structural_role == uproot.const.rntuple_role_leaf + structural_role == uproot.const.RNTupleFieldRole.LEAF and this_record.repetition == 0 ): # deal with std::atomic @@ -363,7 +570,7 @@ def field_form(self, this_id, seen): # base case of recursion # n.b. the split may happen in column return self.col_form(this_id) - elif structural_role == uproot.const.rntuple_role_leaf: + elif structural_role == uproot.const.RNTupleFieldRole.LEAF: if this_id in self._related_ids: # std::array has only one subfield child_id = self._related_ids[this_id][0] @@ -373,7 +580,7 @@ def field_form(self, this_id, seen): inner = self.col_form(this_id) keyname = f"RegularForm-{this_id}" return ak.forms.RegularForm(inner, this_record.repetition, form_key=keyname) - elif structural_role == uproot.const.rntuple_role_vector: + elif structural_role == uproot.const.RNTupleFieldRole.VECTOR: if this_id not in self._related_ids or len(self._related_ids[this_id]) != 1: keyname = f"vector-{this_id}" newids = self._related_ids.get(this_id, []) @@ -397,7 +604,7 @@ def field_form(self, this_id, seen): child_id = self._related_ids[this_id][0] inner = self.field_form(child_id, seen) return ak.forms.ListOffsetForm("i64", inner, form_key=keyname) - elif structural_role == uproot.const.rntuple_role_struct: + elif structural_role == uproot.const.RNTupleFieldRole.STRUCT: newids = [] if this_id in self._related_ids: newids = self._related_ids[this_id] @@ -405,7 +612,7 @@ def field_form(self, this_id, seen): recordlist = [self.field_form(i, seen) for i in newids] namelist = [field_records[i].field_name for i in newids] return ak.forms.RecordForm(recordlist, namelist, form_key="whatever") - elif structural_role == uproot.const.rntuple_role_union: + elif structural_role == uproot.const.RNTupleFieldRole.UNION: keyname = self.col_form(this_id) newids = [] if this_id in self._related_ids: @@ -415,6 +622,10 @@ def field_form(self, this_id, seen): "i8", "i64", recordlist, form_key=keyname + "-union" ) return ak.forms.IndexedOptionForm("i64", inner, form_key=keyname) + elif structural_role == uproot.const.RNTupleFieldRole.UNSPLIT: + raise NotImplementedError( + f"Unsplit fields are not supported. {this_record}" + ) else: # everything should recurse above this branch raise AssertionError("this should be unreachable") @@ -493,7 +704,9 @@ def read_pagedesc(self, destination, desc, dtype_str, dtype, nbits, split): # needed to chop off extra bits incase we used `unpackbits` destination[:] = content[:num_elements] - def read_col_pages(self, ncol, cluster_range, dtype_byte, pad_missing_ele=False): + def read_col_pages( + self, ncol, cluster_range, dtype_byte, pad_missing_element=False + ): arrays = [self.read_col_page(ncol, i) for i in cluster_range] # Check if column stores offset values for jagged arrays (splitindex64) (applies to cardinality cols too): @@ -515,9 +728,9 @@ def read_col_pages(self, ncol, cluster_range, dtype_byte, pad_missing_ele=False) res = numpy.concatenate(arrays, axis=0) - if pad_missing_ele: - first_ele_index = self.column_records[ncol].first_ele_index - res = numpy.pad(res, (first_ele_index, 0)) + if pad_missing_element: + first_element_index = self.column_records[ncol].first_element_index + res = numpy.pad(res, (first_element_index, 0)) return res def read_col_page(self, ncol, cluster_i): @@ -554,7 +767,7 @@ def read_col_page(self, ncol, cluster_i): if index: res = numpy.insert(res, 0, 0) # for offsets if zigzag: - res = from_zigzag(res) + res = _from_zigzag(res) elif delta: res = numpy.cumsum(res) return res @@ -568,67 +781,22 @@ def arrays( decompression_executor=None, array_cache=None, ): - ak = uproot.extras.awkward() - - entry_stop = entry_stop or self._length - - clusters = self.cluster_summaries - cluster_starts = numpy.array([c.num_first_entry for c in clusters]) - - start_cluster_idx = ( - numpy.searchsorted(cluster_starts, entry_start, side="right") - 1 - ) - stop_cluster_idx = numpy.searchsorted(cluster_starts, entry_stop, side="right") - cluster_num_entries = numpy.sum( - [c.num_entries for c in clusters[start_cluster_idx:stop_cluster_idx]] + return _arrays( + self, + filter_name=filter_name, + filter_typename=filter_typename, + entry_start=entry_start, + entry_stop=entry_stop, + decompression_executor=decompression_executor, + array_cache=array_cache, ) - form = self.to_akform().select_columns( - filter_name, prune_unions_and_records=False - ) - # only read columns mentioned in the awkward form - target_cols = [] - container_dict = {} - _recursive_find(form, target_cols) - for key in target_cols: - if "column" in key and "union" not in key: - key_nr = int(key.split("-")[1]) - dtype_byte = self.column_records[key_nr].type - - content = self.read_col_pages( - key_nr, - range(start_cluster_idx, stop_cluster_idx), - dtype_byte=dtype_byte, - pad_missing_ele=True, - ) - if "cardinality" in key: - content = numpy.diff(content) - if dtype_byte == uproot.const.rntuple_col_type_to_num_dict["switch"]: - kindex, tags = _split_switch_bits(content) - # Find invalid variants and adjust buffers accordingly - invalid = numpy.flatnonzero(tags == -1) - if len(invalid) > 0: - kindex = numpy.delete(kindex, invalid) - tags = numpy.delete(tags, invalid) - invalid -= numpy.arange(len(invalid)) - optional_index = numpy.insert( - numpy.arange(len(kindex), dtype=numpy.int64), invalid, -1 - ) - else: - optional_index = numpy.arange(len(kindex), dtype=numpy.int64) - container_dict[f"{key}-index"] = optional_index - container_dict[f"{key}-union-index"] = kindex - container_dict[f"{key}-union-tags"] = tags - else: - # don't distinguish data and offsets - container_dict[f"{key}-data"] = content - container_dict[f"{key}-offsets"] = content - cluster_offset = cluster_starts[start_cluster_idx] - entry_start -= cluster_offset - entry_stop -= cluster_offset - return ak.from_buffers( - form, cluster_num_entries, container_dict, allow_noncanonical_form=True - )[entry_start:entry_stop] + def iterate(self, filter_name="*", *args, step_size="100 MB", **kwargs): + step_size = _regularize_step_size(self, step_size, filter_name) + for start in range(0, self.num_entries, step_size): + yield self.arrays( + *args, entry_start=start, entry_stop=start + step_size, **kwargs + ) # Supporting function and classes @@ -672,6 +840,9 @@ def __init__(self): def read(self, chunk, cursor, context): out = MetaData(type(self).__name__) out.env_header = _envelop_header(chunk, cursor, context) + assert ( + out.env_header["env_type_id"] == uproot.const.RNTupleEnvelopeType.PAGELIST + ), f"env_type_id={out.env_header['env_type_id']}" out.header_checksum = cursor.field(chunk, _rntuple_checksum_format, context) out.cluster_summaries = self.list_cluster_summaries.read(chunk, cursor, context) out.pagelinklist = self.nested_page_locations.read(chunk, cursor, context) @@ -682,9 +853,23 @@ def read(self, chunk, cursor, context): class LocatorReader: def read(self, chunk, cursor, context): out = MetaData("Locator") - out.num_bytes, out.offset = cursor.fields( - chunk, _rntuple_locator_format, context - ) + out.num_bytes = cursor.field(chunk, _rntuple_locator_size_format, context) + if out.num_bytes < 0: + out.type = -out.num_bytes >> 24 + if out.type == uproot.const.RNTupleLocatorType.LARGE: + out.num_bytes = cursor.field( + chunk, _rntuple_large_locator_size_format, context + ) + out.offset = cursor.field( + chunk, _rntuple_locator_offset_format, context + ) + elif out.type == uproot.const.RNTupleLocatorType.DAOS: + raise NotImplementedError("DAOS locators are not supported.") + else: + raise NotImplementedError(f"Unknown locator type: {out.type}") + else: + out.type = uproot.const.RNTupleLocatorType.STANDARD + out.offset = cursor.field(chunk, _rntuple_locator_offset_format, context) return out @@ -725,7 +910,7 @@ def __init__(self, payload): def read(self, chunk, cursor, context): local_cursor = cursor.copy() - num_bytes = local_cursor.field(chunk, _rntuple_record_size_format, context) + num_bytes = local_cursor.field(chunk, _rntuple_frame_size_format, context) assert num_bytes >= 0, f"num_bytes={num_bytes}" cursor.skip(num_bytes) return self.payload.read(chunk, local_cursor, context) @@ -737,10 +922,9 @@ def __init__(self, payload): def read(self, chunk, cursor, context): local_cursor = cursor.copy() - num_bytes, num_items = local_cursor.fields( - chunk, _rntuple_frame_header_format, context - ) + num_bytes = local_cursor.field(chunk, _rntuple_frame_size_format, context) assert num_bytes < 0, f"num_bytes={num_bytes}" + num_items = local_cursor.field(chunk, _rntuple_frame_num_items_format, context) cursor.skip(-num_bytes) return [ self.payload.read(chunk, local_cursor, context) for _ in range(num_items) @@ -758,10 +942,24 @@ def read(self, chunk, cursor, context): out.struct_role, out.flags, ) = cursor.fields(chunk, _rntuple_field_description_format, context) - if out.flags == 0x0001: + if out.flags == uproot.const.RNTupleFieldFlag.REPETITIVE: out.repetition = cursor.field(chunk, _rntuple_repetition_format, context) + out.source_field_id = None + out.checksum = None + elif out.flags == uproot.const.RNTupleFieldFlag.PROJECTED: + out.repetition = 0 + out.source_field_id = cursor.field( + chunk, _rntuple_source_field_id_format, context + ) + out.checksum = None + elif out.flags == uproot.const.RNTupleFieldFlag.CHECKSUM: + out.repetition = 0 + out.source_field_id = None + out.checksum = cursor.field(chunk, _rntuple_checksum_format, context) else: out.repetition = 0 + out.source_field_id = None + out.checksum = None out.field_name, out.type_name, out.type_alias, out.field_desc = ( cursor.rntuple_string(chunk, context) for _ in range(4) ) @@ -772,15 +970,21 @@ def read(self, chunk, cursor, context): class ColumnRecordReader: def read(self, chunk, cursor, context): out = MetaData("ColumnRecordFrame") - out.type, out.nbits, out.field_id, out.flags = cursor.fields( + out.type, out.nbits, out.field_id, out.flags, out.repr_idx = cursor.fields( chunk, _rntuple_column_record_format, context ) - if out.flags & 0x08: - out.first_ele_index = cursor.field( - chunk, _rntuple_first_ele_index_format, context + if out.flags & uproot.const.RNTupleColumnFlag.DEFERRED: + out.first_element_index = cursor.field( + chunk, _rntuple_first_element_index_format, context + ) + else: + out.first_element_index = 0 + if out.flags & uproot.const.RNTupleColumnFlag.RANGE: + out.min_value, out.max_value = cursor.fields( + chunk, _rntuple_column_range_format, context ) else: - out.first_ele_index = 0 + out.min_value, out.max_value = None, None return out @@ -798,7 +1002,7 @@ class ExtraTypeInfoReader: def read(self, chunk, cursor, context): out = MetaData("ExtraTypeInfoReader") - out.type_ver_from, out.type_ver_to, out.content_id = cursor.fields( + out.content_id, out.type_ver_from, out.type_ver_to = cursor.fields( chunk, _rntuple_extra_type_info_format, context ) out.type_name = cursor.rntuple_string(chunk, context) @@ -823,6 +1027,9 @@ def __init__(self): def read(self, chunk, cursor, context): out = MetaData(type(self).__name__) out.env_header = _envelop_header(chunk, cursor, context) + assert ( + out.env_header["env_type_id"] == uproot.const.RNTupleEnvelopeType.HEADER + ), f"env_type_id={out.env_header['env_type_id']}" out.feature_flag = cursor.field(chunk, _rntuple_feature_flag_format, context) out.name, out.ntuple_description, out.writer_identifier = ( cursor.rntuple_string(chunk, context) for _ in range(3) @@ -863,6 +1070,10 @@ def read(self, chunk, cursor, context): out.num_first_entry, out.num_entries = cursor.fields( chunk, _rntuple_cluster_summary_format, context ) + out.flags = out.num_entries >> 56 + out.num_entries &= 0xFFFFFFFFFFFFFF + if out.flags == uproot.const.RNTupleClusterFlag.SHARDED: + raise NotImplementedError("Sharded clusters are not supported.") return out @@ -879,7 +1090,7 @@ def read(self, chunk, cursor, context): class RNTupleSchemaExtension: def read(self, chunk, cursor, context): out = MetaData(type(self).__name__) - out.size = cursor.field(chunk, _rntuple_record_size_format, context) + out.size = cursor.field(chunk, _rntuple_frame_size_format, context) assert out.size >= 0, f"size={out.size}" out.field_records = ListFrameReader( RecordFrameReader(FieldRecordReader()) @@ -908,11 +1119,13 @@ def __init__(self): self.cluster_group_record_frames = ListFrameReader( RecordFrameReader(ClusterGroupRecordReader()) ) - self.meta_data_links = ListFrameReader(RecordFrameReader(EnvLinkReader())) def read(self, chunk, cursor, context): out = MetaData("Footer") out.env_header = _envelop_header(chunk, cursor, context) + assert ( + out.env_header["env_type_id"] == uproot.const.RNTupleEnvelopeType.FOOTER + ), f"env_type_id={out.env_header['env_type_id']}" out.feature_flag = cursor.field(chunk, _rntuple_feature_flag_format, context) out.header_checksum = cursor.field(chunk, _rntuple_checksum_format, context) out.extension_links = self.extension_header_links.read(chunk, cursor, context) @@ -922,11 +1135,150 @@ def read(self, chunk, cursor, context): out.cluster_group_records = self.cluster_group_record_frames.read( chunk, cursor, context ) - out.meta_block_links = self.meta_data_links.read(chunk, cursor, context) out.checksum = cursor.field(chunk, _rntuple_checksum_format, context) return out +class RNTupleField: + def __init__(self, index, ntuple): + self.index = index + self.ntuple = ntuple + self._length = None + + @property + def _keys(self): + keys = [] + for i, fr in enumerate(self.ntuple.field_records): + if i == self.index: + continue + if ( + fr.parent_field_id == self.index + and fr.type_name != "" + and not fr.field_name.startswith("_") + and not fr.field_name.startswith(":_") + ): + keys.append(fr.field_name) + return keys + + def keys(self): + return self._keys + + @property + def name(self): + """ + Name of the ``Field``. + """ + return self.ntuple.field_records[self.index].field_name + + def __len__(self): + if self._length is None: + self._length = len(self.keys()) + return self._length + + def __repr__(self): + if len(self) == 0: + return f"" + else: + return f"" + + @property + def _key_indices(self): + indices = [] + field_records = self.ntuple.field_records + for i, fr in enumerate(field_records): + if fr.parent_field_id == self.index and fr.type_name != "": + indices.append(i) + return indices + + @property + def _key_to_index(self): + d = {} + field_records = self.ntuple.field_records + for i, fr in enumerate(field_records): + if fr.parent_field_id == self.index and fr.type_name != "": + d[fr.field_name] = i + return d + + def __getitem__(self, where): + # original_where = where + + if uproot._util.isint(where): + index = self._key_indices[where] + elif isinstance(where, str): + where = uproot._util.ensure_str(where) + index = self._key_to_index[where] + else: + raise TypeError(f"where must be an integer or a string, not {where!r}") + + # TODO: Implement path support + + return RNTupleField(index, self.ntuple) + + def to_akform(self): + ak = uproot.extras.awkward() + + field_records = self.ntuple.field_records + recordlist = [] + topnames = self.keys() + if len(topnames) == 0: + topnames = [self.name] + recordlist.append(self.ntuple.field_form(self.index, set())) + else: + seen = set() + for i in range(len(field_records)): + if ( + i not in seen + and field_records[i].parent_field_id == self.index + and i != self.index + and not field_records[i].field_name.startswith("_") + and not field_records[i].field_name.startswith(":_") + ): + ff = self.ntuple.field_form(i, seen) + if field_records[i].type_name != "": + recordlist.append(ff) + + form = ak.forms.RecordForm(recordlist, topnames, form_key="toplevel") + return form + + def arrays( + self, + filter_name="*", + filter_typename=None, + entry_start=0, + entry_stop=None, + decompression_executor=None, + array_cache=None, + ): + return _arrays( + self, + filter_name=filter_name, + filter_typename=filter_typename, + entry_start=entry_start, + entry_stop=entry_stop, + decompression_executor=decompression_executor, + array_cache=array_cache, + ) + + def array(self, **kwargs): + if len(self.keys()) == 0: + return self.arrays(**kwargs)[self.name] + return self.arrays(**kwargs) + + def __array__(self, *args, **kwargs): + out = self.array() + if args == () and kwargs == {}: + return out + else: + return numpy.array(out, *args, **kwargs) + + def iterate(self, filter_name="*", *args, step_size="100 MB", **kwargs): + step_size = _regularize_step_size(self, step_size, filter_name) + for start in range(0, self.ntuple.num_entries, step_size): + yield self.array( + *args, entry_start=start, entry_stop=start + step_size, **kwargs + ) + + uproot.classes["ROOT::Experimental::RNTuple"] = ( Model_ROOT_3a3a_Experimental_3a3a_RNTuple ) diff --git a/src/uproot/writing/_cascadentuple.py b/src/uproot/writing/_cascadentuple.py index bdbb795e3..8fcb39142 100644 --- a/src/uproot/writing/_cascadentuple.py +++ b/src/uproot/writing/_cascadentuple.py @@ -29,8 +29,8 @@ _rntuple_column_record_format, _rntuple_feature_flag_format, _rntuple_field_description_format, - _rntuple_locator_format, - _rntuple_record_size_format, + _rntuple_frame_size_format, + _rntuple_locator_size_format, ) from uproot.writing._cascade import CascadeLeaf, CascadeNode, Key, String @@ -114,7 +114,7 @@ def _record_frame_wrap(payload, includeself=True): aloc = len(payload) if includeself: aloc += 4 - raw_bytes = _rntuple_record_size_format.pack(aloc) + payload + raw_bytes = _rntuple_frame_size_format.pack(aloc) + payload return raw_bytes @@ -377,7 +377,7 @@ def __init__(self, num_bytes, offset): self.offset = offset def serialize(self): - outbytes = _rntuple_locator_format.pack(self.num_bytes, self.offset) + outbytes = _rntuple_locator_size_format.pack(self.num_bytes, self.offset) return outbytes def __repr__(self): @@ -532,7 +532,7 @@ def __init__(self, num_bytes, offset): self.offset = offset def serialize(self): - outbytes = _rntuple_locator_format.pack(self.num_bytes, self.offset) + outbytes = _rntuple_locator_size_format.pack(self.num_bytes, self.offset) return outbytes def __repr__(self): diff --git a/tests-wasm/test_1272_basic_functionality.py b/tests-wasm/test_1272_basic_functionality.py index ed0495508..1cdc5163e 100644 --- a/tests-wasm/test_1272_basic_functionality.py +++ b/tests-wasm/test_1272_basic_functionality.py @@ -79,7 +79,7 @@ def test_read_rntuple(selenium): assert len(obj.column_records) > len(obj.header.column_records) assert len(obj.column_records) == 936 - assert obj.column_records[903].first_ele_index == 36 + assert obj.column_records[903].first_element_index == 36 arrays = obj.arrays() diff --git a/tests/test_0013_rntuple_anchor.py b/tests/test_0013_rntuple_anchor.py index 224fa128e..14133b93b 100644 --- a/tests/test_0013_rntuple_anchor.py +++ b/tests/test_0013_rntuple_anchor.py @@ -25,7 +25,7 @@ def test(): assert obj.member("fSeekFooter") == 36420 assert obj.member("fNBytesFooter") == 89 assert obj.member("fLenFooter") == 172 - assert obj.member("fChecksum") == 12065027575882477574 + assert obj.member("fMaxKeySize") == 12065027575882477574 header_start = obj.member("fSeekHeader") header_stop = header_start + obj.member("fNBytesHeader") diff --git a/tests/test_1191_rntuple_fixes.py b/tests/test_1191_rntuple_fixes.py index a1e259310..1aedf4093 100644 --- a/tests/test_1191_rntuple_fixes.py +++ b/tests/test_1191_rntuple_fixes.py @@ -13,7 +13,7 @@ def test_schema_extension(): assert len(obj.column_records) > len(obj.header.column_records) assert len(obj.column_records) == 936 - assert obj.column_records[903].first_ele_index == 36 + assert obj.column_records[903].first_element_index == 36 arrays = obj.arrays() diff --git a/tests/test_1250_rntuple_improvements.py b/tests/test_1250_rntuple_improvements.py new file mode 100644 index 000000000..8a169ab86 --- /dev/null +++ b/tests/test_1250_rntuple_improvements.py @@ -0,0 +1,85 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE + +import pytest +import skhep_testdata + +import uproot + + +def test_field_class(): + filename = skhep_testdata.data_path("DAOD_TRUTH3_RC2.root") + with uproot.open(filename) as f: + obj = f["RNT:CollectionTree"] + jets = obj["AntiKt4TruthDressedWZJetsAux:"] + assert len(jets) == 6 + + pt = jets["pt"] + assert len(pt) == 0 + + +def test_array_methods(): + filename = skhep_testdata.data_path( + "Run2012BC_DoubleMuParked_Muons_rntuple_1000evts.root" + ) + with uproot.open(filename) as f: + obj = f["Events"] + nMuon_array = obj["nMuon"].array() + Muon_pt_array = obj["Muon_pt"].array() + assert nMuon_array.tolist() == [len(l) for l in Muon_pt_array] + + nMuon_arrays = obj["nMuon"].arrays() + assert len(nMuon_arrays.fields) == 1 + assert len(nMuon_arrays) == 1000 + assert nMuon_arrays["nMuon"].tolist() == nMuon_array.tolist() + + filename = skhep_testdata.data_path("DAOD_TRUTH3_RC2.root") + with uproot.open(filename) as f: + obj = f["RNT:CollectionTree"] + jets = obj["AntiKt4TruthDressedWZJetsAux:"].arrays() + assert len(jets[0].pt) == 5 + + +def test_iterate(): + filename = skhep_testdata.data_path( + "Run2012BC_DoubleMuParked_Muons_rntuple_1000evts.root" + ) + with uproot.open(filename) as f: + obj = f["Events"] + for i, arrays in enumerate(obj.iterate(step_size=100)): + assert len(arrays) == 100 + if i == 0: + expected_pt = [10.763696670532227, 15.736522674560547] + expected_charge = [-1, -1] + assert arrays["Muon_pt"][0].tolist() == expected_pt + assert arrays["Muon_charge"][0].tolist() == expected_charge + + for i, arrays in enumerate(obj.iterate(step_size="10 kB")): + if i == 0: + assert len(arrays) == 363 + expected_pt = [10.763696670532227, 15.736522674560547] + expected_charge = [-1, -1] + assert arrays["Muon_pt"][0].tolist() == expected_pt + assert arrays["Muon_charge"][0].tolist() == expected_charge + elif i == 1: + assert len(arrays) == 363 + elif i == 2: + assert len(arrays) == 274 + else: + assert False + + Muon_pt = obj["Muon_pt"] + for i, arrays in enumerate(Muon_pt.iterate(step_size=100)): + assert len(arrays) == 100 + if i == 0: + expected_pt = [10.763696670532227, 15.736522674560547] + assert arrays[0].tolist() == expected_pt + + for i, arrays in enumerate(Muon_pt.iterate(step_size="5 kB")): + if i == 0: + assert len(arrays) == 611 + expected_pt = [10.763696670532227, 15.736522674560547] + assert arrays[0].tolist() == expected_pt + elif i == 1: + assert len(arrays) == 389 + else: + assert False diff --git a/tests/test_1285_rntuple_multicluster_concatenation.py b/tests/test_1285_rntuple_multicluster_concatenation.py new file mode 100644 index 000000000..392d33f6e --- /dev/null +++ b/tests/test_1285_rntuple_multicluster_concatenation.py @@ -0,0 +1,20 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot5/blob/main/LICENSE + +import skhep_testdata +import numpy as np + +import uproot + + +def test_schema_extension(): + filename = skhep_testdata.data_path("test_ntuple_index_multicluster.root") + with uproot.open(filename) as f: + obj = f["ntuple"] + + arrays = obj.arrays() + int_vec_array = arrays["int_vector"] + + for j in range(2): + for i in range(100): + assert int_vec_array[i + j * 100, 0] == i + assert int_vec_array[i + j * 100, 1] == i + j