Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
Merge pull request #379 from masonproffitt/issue376
Browse files Browse the repository at this point in the history
Provide method to get ROOT class names as plain strings
  • Loading branch information
jpivarski authored Nov 3, 2019
2 parents 7e220dd + 93a2956 commit 7052188
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
17 changes: 17 additions & 0 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,23 @@ def test_issue371(self):
assert obj._n == 1
assert obj._energy[0] == numpy.array([2.3371024], dtype=numpy.float32)[0]

def test_issue376_simple(self):
f = uproot.open("tests/samples/from-geant4.root")
assert type(f).classname == 'TDirectory'
assert f.classname == 'TDirectory'
real_class_names = ['TTree'] * 4 + ['TH1D'] * 10 + ['TH2D'] * 5
assert [classname_two_tuple[1] for classname_two_tuple in f.classnames()] == real_class_names
assert [class_two_tuple[1].classname for class_two_tuple in f.classes()] == real_class_names
assert [value.classname for value in f.values()] == real_class_names

def test_issue376_nested(self):
f = uproot.open("tests/samples/nesteddirs.root")
top_level_class_names = ['TDirectory', 'TDirectory']
recursive_class_names = ['TDirectory', 'TDirectory', 'TTree', 'TTree', 'TDirectory', 'TTree']
assert [classname_two_tuple[1] for classname_two_tuple in f.classnames(recursive=False)] == top_level_class_names
assert [classname_two_tuple[1] for classname_two_tuple in f.classnames(recursive=True)] == recursive_class_names
assert [classname_two_tuple[1] for classname_two_tuple in f.allclassnames()] == recursive_class_names

def test_issue367(self):
t = uproot.open("tests/samples/issue367.root")["tree"]
assert awkward.fromiter(t.array("weights.second"))[0].counts.tolist() == [1000, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 100, 100, 100, 1]
Expand Down
112 changes: 111 additions & 1 deletion uproot/rootio.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class ROOTDirectory(object):
__metaclass__ = type.__new__(type, "type", (type,), {})

_classname = b"TDirectory"
classname = "TDirectory"

class _FileContext(object):
def __init__(self, sourcepath, streamerinfos, streamerinfosmap, classes, compression, tfile):
Expand Down Expand Up @@ -285,6 +286,16 @@ def iterclasses(self, recursive=False, filtername=nofilter, filterclass=nofilter
for name, classname in key.get().iterclasses(recursive, filtername, filterclass):
yield "{0}/{1}".format(self._withoutcycle(key).decode("ascii"), name.decode("ascii")).encode("ascii"), classname

def iterclassnames(self, recursive=False, filtername=nofilter, filterclass=nofilter):
for key in self._keys:
cls = _classof(self._context, key._fClassName)
if filtername(key._fName) and filterclass(cls):
yield self._withcycle(key), key._fClassName.decode('ascii')

if recursive and (key._fClassName == b"TDirectory" or key._fClassName == b"TDirectoryFile"):
for name, classname in key.get().iterclassnames(recursive, filtername, filterclass):
yield "{0}/{1}".format(self._withoutcycle(key).decode("ascii"), name.decode("ascii")).encode("ascii"), classname

def keys(self, recursive=False, filtername=nofilter, filterclass=nofilter):
return list(self.iterkeys(recursive=recursive, filtername=filtername, filterclass=filterclass))

Expand All @@ -301,6 +312,9 @@ def items(self, recursive=False, filtername=nofilter, filterclass=nofilter):
def classes(self, recursive=False, filtername=nofilter, filterclass=nofilter):
return list(self.iterclasses(recursive=recursive, filtername=filtername, filterclass=filterclass))

def classnames(self, recursive=False, filtername=nofilter, filterclass=nofilter):
return list(self.iterclassnames(recursive=recursive, filtername=filtername, filterclass=filterclass))

def allkeys(self, filtername=nofilter, filterclass=nofilter):
return self.keys(recursive=True, filtername=filtername, filterclass=filterclass)

Expand All @@ -313,6 +327,9 @@ def allitems(self, filtername=nofilter, filterclass=nofilter):
def allclasses(self, filtername=nofilter, filterclass=nofilter):
return self.classes(recursive=True, filtername=filtername, filterclass=filterclass)

def allclassnames(self, filtername=nofilter, filterclass=nofilter):
return self.classnames(recursive=True, filtername=filtername, filterclass=filterclass)

def get(self, name, cycle=None):
name = _bytesid(name)

Expand Down Expand Up @@ -872,6 +889,7 @@ def _defineclasses(streamerinfos, classes):
code.insert(0, " _hasreadobjany = {0}".format(hasreadobjany))
code.insert(0, " _classversion = {0}".format(streamerinfo._fClassVersion))
code.insert(0, " _versions = versions")
code.insert(0, " classname = {0}".format(repr(streamerinfo._fName.decode("ascii"))))
if sys.version_info[0] > 2:
code.insert(0, " _classname = {0}".format(repr(streamerinfo._fName)))
else:
Expand Down Expand Up @@ -946,6 +964,9 @@ def __repr__(self):
return "<{0} at 0x{1:012x}>".format(self.__class__.__name__, id(self))

class TKey(ROOTObject):
_classname = b"TKey"
classname = "TKey"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start = cursor.index
Expand Down Expand Up @@ -1034,6 +1055,9 @@ def _canonicaltype(name):
]

class TStreamerInfo(ROOTObject):
_classname = b"TStreamerInfo"
classname = "TStreamerInfo"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1055,6 +1079,9 @@ def show(self, stream=sys.stdout):
stream.write("\n")

class TStreamerElement(ROOTObject):
_classname = b"TStreamerElement"
classname = "TStreamerElement"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand Down Expand Up @@ -1104,6 +1131,9 @@ def show(self, stream=sys.stdout):
stream.write("\n")

class TStreamerArtificial(TStreamerElement):
_classname = b"TStreamerArtificial"
classname = "TStreamerArtificial"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1112,6 +1142,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerBase(TStreamerElement):
_classname = b"TStreamerBase"
classname = "TStreamerBase"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1124,6 +1157,9 @@ def _readinto(cls, self, source, cursor, context, parent):
_format = struct.Struct(">i")

class TStreamerBasicPointer(TStreamerElement):
_classname = b"TStreamerBasicPointer"
classname = "TStreamerBasicPointer"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1137,6 +1173,9 @@ def _readinto(cls, self, source, cursor, context, parent):
_format = struct.Struct(">i")

class TStreamerBasicType(TStreamerElement):
_classname = b"TStreamerBasicType"
classname = "TStreamerBasicType"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand Down Expand Up @@ -1170,6 +1209,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerLoop(TStreamerElement):
_classname = b"TStreamerLoop"
classname = "TStreamerLoop"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1183,6 +1225,9 @@ def _readinto(cls, self, source, cursor, context, parent):
_format = struct.Struct(">i")

class TStreamerObject(TStreamerElement):
_classname = b"TStreamerObject"
classname = "TStreamerObject"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1191,6 +1236,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerObjectAny(TStreamerElement):
_classname = b"TStreamerObjectAny"
classname = "TStreamerObjectAny"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1199,6 +1247,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerObjectAnyPointer(TStreamerElement):
_classname = b"TStreamerObjectAnyPointer"
classname = "TStreamerObjectAnyPointer"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1207,6 +1258,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerObjectPointer(TStreamerElement):
_classname = b"TStreamerObjectPointer"
classname = "TStreamerObjectPointer"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1215,6 +1269,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerSTL(TStreamerElement):
_classname = b"TStreamerSTL"
classname = "TStreamerSTL"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand Down Expand Up @@ -1242,6 +1299,9 @@ def vector(cls, fType, fTypeName):
_format = struct.Struct(">ii")

class TStreamerSTLstring(TStreamerSTL):
_classname = b"TStreamerSTLstring"
classname = "TStreamerSTLstring"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1250,6 +1310,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TStreamerString(TStreamerElement):
_classname = b"TStreamerString"
classname = "TStreamerString"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand Down Expand Up @@ -1302,6 +1365,9 @@ def _recarray_dtype(cls, cntvers=False, tobject=True):
return numpy.dtype(dtypesout)

class TObject(ROOTStreamedObject):
_classname = b"TObject"
classname = "TObject"

@classmethod
def _recarray(cls):
return [(" fBits", numpy.dtype(">u8")), (" fUniqueID", numpy.dtype(">u8"))]
Expand All @@ -1312,6 +1378,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TString(bytes, ROOTStreamedObject):
_classname = b"TString"
classname = "TString"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
return TString(cursor.string(source))
Expand All @@ -1320,6 +1389,8 @@ def __str__(self):
return self.decode("utf-8", "replace")

class TNamed(TObject):
_classname = b"TNamed"
classname = "TNamed"
_fields = ["fName", "fTitle"]

@classmethod
Expand All @@ -1336,6 +1407,9 @@ def _readinto(cls, self, source, cursor, context, parent):
return self

class TObjArray(list, ROOTStreamedObject):
_classname = b"TObjArray"
classname = "TObjArray"

@classmethod
def read(cls, source, cursor, context, parent, asclass=None):
if cls._copycontext:
Expand All @@ -1356,6 +1430,9 @@ def _readinto(cls, self, source, cursor, context, parent, asclass=None):
return self

class TObjString(bytes, ROOTStreamedObject):
_classname = b"TObjString"
classname = "TObjString"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1368,6 +1445,9 @@ def __str__(self):
return self.decode("utf-8", "replace")

class TList(list, ROOTStreamedObject):
_classname = b"TList"
classname = "TList"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
Expand All @@ -1383,12 +1463,18 @@ def _readinto(cls, self, source, cursor, context, parent):
_format_n = struct.Struct(">B")

class THashList(TList):
_classname = b"THashList"
classname = "THashList"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
TList._readinto(self, source, cursor, context, parent)
return self

class TRef(ROOTStreamedObject):
_classname = b"TRef"
classname = "TRef"

_format1 = struct.Struct(">xxIxxxxxx")

def __init__(self, id):
Expand All @@ -1415,6 +1501,9 @@ def _recarray(cls):
TRef._fromrow = lambda row: TRef(row["id"])

class TRefArray(list, ROOTStreamedObject):
_classname = b"TRefArray"
classname = "TRefArray"

_format1 = struct.Struct(">i")
_dtype = numpy.dtype(">i4")

Expand All @@ -1437,6 +1526,9 @@ def tostring(self):
return numpy.asarray(self, dtype=self._dtype).tostring()

class TArray(list, ROOTStreamedObject):
_classname = b"TArray"
classname = "TArray"

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
length = cursor.field(source, TArray._format)
Expand All @@ -1452,24 +1544,38 @@ def tostring(self):
return numpy.asarray(self, dtype=self._dtype).tostring()

class TArrayC(TArray):
_classname = b"TArrayC"
classname = "TArrayC"
_dtype = numpy.dtype(">i1")

class TArrayS(TArray):
_classname = b"TArrayS"
classname = "TArrayS"
_dtype = numpy.dtype(">i2")

class TArrayI(TArray):
_classname = b"TArrayI"
classname = "TArrayI"
_dtype = numpy.dtype(">i4")

class TArrayL(TArray):
_classname = b"TArrayL"
classname = "TArrayL"
_dtype = numpy.dtype(numpy.int_).newbyteorder(">")

class TArrayL64(TArray):
_classname = b"TArrayL64"
classname = "TArrayL64"
_dtype = numpy.dtype(">i8")

class TArrayF(TArray):
_classname = b"TArrayF"
classname = "TArrayF"
_dtype = numpy.dtype(">f4")

class TArrayD(TArray):
_classname = b"TArrayD"
classname = "TArrayD"
_dtype = numpy.dtype(">f8")

# FIXME: I want to generalize this. It's the first example of a class that doesn't
Expand All @@ -1482,19 +1588,23 @@ class TArrayD(TArray):
# I'm also reasonably certain that the last byte is the fIOBits data.
# That leaves 4 bytes unaccounted for.
class ROOT_3a3a_TIOFeatures(ROOTStreamedObject):
_fields = ["fIOBits"]
_classname = b"ROOT::TIOFeatures"
classname = "ROOT::TIOFeatures"
_fields = ["fIOBits"]

@classmethod
def _readinto(cls, self, source, cursor, context, parent):
start, cnt, self._classversion = _startcheck(source, cursor)
cursor.skip(4)
self._fIOBits = cursor.field(source, ROOT_3a3a_TIOFeatures._format1)
_endcheck(start, cursor, cnt)
return self

_format1 = struct.Struct(">B")

class Undefined(ROOTStreamedObject):
_classname = None
classname = None

@classmethod
def read(cls, source, cursor, context, parent, classname=None):
Expand Down

0 comments on commit 7052188

Please sign in to comment.