Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Add support to write Jagged Arrays to TTrees #477

Merged
merged 10 commits into from
May 1, 2020
279 changes: 279 additions & 0 deletions tests/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import pytest
import numpy
import ctypes

import awkward

import uproot
from uproot.write.objects.TTree import newtree, newbranch
Expand Down Expand Up @@ -1855,3 +1858,279 @@ def test_tree_threedim(tmp_path):
for j in range(3):
for k in range(4):
assert a[i][j][k] == test[j][k]

def test_jagged_i4(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i4"), counter="n")})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename "counter" to "size"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened PR #481

f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch] == a[i]))

def test_jagged_uproot_i4(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i4"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = uproot.open(filename)
array = f["t"].array(["branch"])
for i in range(len(array)):
for j in range(len(array[i])):
assert(array[i][j] == a[i][j])

#Need to use C++ code to read out because of bug in PyROOT layer (of Conda ROOT build?)
def test_jagged_i8(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i8"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

ROOT.gInterpreter.Declare("""
void assertint(bool &flag, char* filename) {
TFile *f = new TFile(filename);
Long64_t x;
Int_t num;
auto tree = f->Get<TTree>("t");
auto n = tree->GetBranch("n");
auto branch = tree->GetBranch("branch");
n->SetAddress(&num);
branch->SetAddress(&x);
Long64_t values[3][3] = {{0,0,0}, {1, 2, 0}, {10, 11, 12}};
for (int i=0; i<tree->GetEntries(); i++) {
tree->GetEvent(i);
for (int j=0; j<num; j++) {
if (values[i][j] != x+j)
flag = false;
}
}
}""")

flag = ctypes.c_bool(True)
ROOT.assertint(flag, filename)
assert(flag)

def test_jagged_uproot_i8(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i8"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = uproot.open(filename)
array = f["t"].array(["branch"])
for i in range(len(array)):
for j in range(len(array[i])):
assert(array[i][j] == a[i][j])

def test_jagged_f8(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">f8"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch] == a[i]))

def test_jagged_uproot_f8(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">f8"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = uproot.open(filename)
array = f["t"].array(["branch"])
for i in range(len(array)):
for j in range(len(array[i])):
assert(array[i][j] == a[i][j])

def test_jagged_f4(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">f4"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch] == a[i]))

def test_jagged_uproot_f4(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">f4"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = uproot.open(filename)
array = f["t"].array(["branch"])
for i in range(len(array)):
for j in range(len(array[i])):
assert(array[i][j] == a[i][j])

def test_jagged_i2(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i2"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch] == a[i]))

def test_jagged_uproot_i2(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i2"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})

f = uproot.open(filename)
array = f["t"].array(["branch"])
for i in range(len(array)):
for j in range(len(array[i])):
assert(array[i][j] == a[i][j])

def test_jagged_i2_multiple_sametype(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2]])

b = awkward.fromiter([[3],
[7, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch1": uproot.newbranch(numpy.dtype(">i2"), counter="n"),
"branch2": uproot.newbranch(numpy.dtype(">i2"), counter="n")})
f["t"].extend({"branch1": a,
"branch2": b,
"n": [1, 2]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch1] == a[i]))
assert(numpy.all([x for x in event.branch2] == b[i]))

def test_jagged_multiple_difftype(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2]])

b = awkward.fromiter([[3],
[7, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch1": uproot.newbranch(numpy.dtype(">i2"), counter="n"),
"branch2": uproot.newbranch(numpy.dtype(">i4"), counter="n")})
f["t"].extend({"branch1": a,
"branch2": b,
"n": [1, 2]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch1] == a[i]))
assert(numpy.all([x for x in event.branch2] == b[i]))

def test_jagged_i2_multiple_difflen(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2]])

b = awkward.fromiter([[3],
[10, 11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch1": uproot.newbranch(numpy.dtype(">i2"), counter="n1"),
"branch2": uproot.newbranch(numpy.dtype(">i2"), counter="n2")})
f["t"].extend({"branch1": a,
"n1": [1, 2],
"branch2": b,
"n2": [1, 3]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch1] == a[i]))
assert(numpy.all([x for x in event.branch2] == b[i]))

def test_jagged_i4_manybasket(tmp_path):
filename = join(str(tmp_path), "example.root")

a = awkward.fromiter([[0],
[1, 2],
[10, 11, 12]])
b = awkward.fromiter([[10],
[11, 12]])
tester = awkward.fromiter([[0],
[1, 2],
[10, 11, 12],
[10],
[11, 12]])

with uproot.recreate(filename, compression=None) as f:
f["t"] = uproot.newtree({"branch": uproot.newbranch(numpy.dtype(">i4"), counter="n")})
f["t"].extend({"branch": a, "n": [1, 2, 3]})
f["t"].extend({"branch": b, "n": [1, 2]})

f = ROOT.TFile.Open(filename)
tree = f.Get("t")
for i, event in enumerate(tree):
assert(numpy.all([x for x in event.branch] == tester[i]))
16 changes: 13 additions & 3 deletions uproot/write/TKey.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(self, fName, fTitle, fNevBuf, fNevBufSize, fObjlen=0, fSeekKey=100,
self.fNevBuf = fNevBuf
self.fNevBufSize = fNevBufSize

self.old_fLast = 0

@property
def fKeylen(self):
return self._format1.size + uproot.write.sink.cursor.Cursor.length_strings([self.fClassName, self.fName, self.fTitle]) + self._format_basketkey.size + 1
Expand All @@ -36,7 +38,7 @@ def fLast(self):
def update(self):
self.cursor.update_fields(self.sink, self._format1, self.fNbytes, self._version, self.fObjlen, self.fDatime, self.fKeylen, self.fCycle, self.fSeekKey, self.fSeekPdir)

def write(self, cursor, sink):
def write(self, cursor, sink, isjagged=False):
self.cursor = uproot.write.sink.cursor.Cursor(cursor.index)
self.sink = sink

Expand All @@ -48,7 +50,13 @@ def write(self, cursor, sink):
cursor.write_string(sink, self.fTitle)

basketversion = 3
cursor.write_fields(sink, self._format_basketkey, basketversion, self.fBufferSize, self.fNevBufSize, self.fNevBuf, self.fLast)
if isjagged:
if self.old_fLast == 0:
raise Exception("isjagged flag should be False")
cursor.write_fields(sink, self._format_basketkey, basketversion, self.fBufferSize, self.fNevBufSize, self.fNevBuf, self.old_fLast)
else:
cursor.write_fields(sink, self._format_basketkey, basketversion, self.fBufferSize, self.fNevBufSize, self.fNevBuf, self.fLast)
self.old_fLast = self.fLast
cursor.write_data(sink, b"\x00")

_version = 1004
Expand All @@ -75,7 +83,9 @@ def fKeylen(self):
def update(self):
self.cursor.update_fields(self.sink, self._format1, self.fNbytes, self._version, self.fObjlen, self.fDatime, self.fKeylen, self.fCycle, self.fSeekKey, self.fSeekPdir)

def write(self, cursor, sink):
def write(self, cursor, sink, isjagged=False):
if isjagged:
raise Exception("isjagged flag should be False")
self.cursor = uproot.write.sink.cursor.Cursor(cursor.index)
self.sink = sink

Expand Down
21 changes: 12 additions & 9 deletions uproot/write/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class LZ4(Compression): pass
uproot.const.kLZMA: LZMA,
uproot.const.kLZ4: LZ4}

def write(context, cursor, givenbytes, compression, key, keycursor):
def write(context, cursor, givenbytes, compression, key, keycursor, isjagged=False):
retaincursor = copy.copy(keycursor)
if compression is None:
algorithm, level = 0, 0
Expand All @@ -64,9 +64,12 @@ def write(context, cursor, givenbytes, compression, key, keycursor):
uncompressedbytes = len(givenbytes)

if algorithm == 0 or level == 0:
key.fObjlen = uncompressedbytes
if isjagged:
key.fObjlen += uncompressedbytes
else:
key.fObjlen = uncompressedbytes
key.fNbytes = key.fObjlen + key.fKeylen
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
cursor.write_data(context._sink, givenbytes)
return

Expand Down Expand Up @@ -94,10 +97,10 @@ def write(context, cursor, givenbytes, compression, key, keycursor):
cursor.write_fields(context._sink, _header, algo, method, c1, c2, c3, u1, u2, u3)
cursor.write_data(context._sink, after_compressed)
key.fNbytes += compressedbytes + 9
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
else:
key.fNbytes += uncompressedbytes
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
cursor.write_data(context._sink, givenbytes)

elif algorithm == uproot.const.kLZ4:
Expand Down Expand Up @@ -125,10 +128,10 @@ def write(context, cursor, givenbytes, compression, key, keycursor):
cursor.write_data(context._sink, checksum)
cursor.write_data(context._sink, after_compressed)
key.fNbytes += compressedbytes + 9
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
else:
key.fNbytes += uncompressedbytes
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
cursor.write_data(context._sink, givenbytes)

elif algorithm == uproot.const.kLZMA:
Expand All @@ -151,10 +154,10 @@ def write(context, cursor, givenbytes, compression, key, keycursor):
cursor.write_fields(context._sink, _header, algo, method, c1, c2, c3, u1, u2, u3)
cursor.write_data(context._sink, after_compressed)
key.fNbytes += compressedbytes + 9
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
else:
key.fNbytes += uncompressedbytes
key.write(keycursor, context._sink)
key.write(keycursor, context._sink, isjagged)
cursor.write_data(context._sink, givenbytes)

elif algorithm == uproot.const.kOldCompressionAlgo:
Expand Down
Loading