This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Speed improvements and numba jit failures #41

Closed
vkuznet opened this issue Dec 21, 2017 · 2 comments

Comments

@vkuznet

vkuznet commented Dec 21, 2017

If I read the /afs/cern.ch/user/v/valya/public/nano-RelValTTBar.root file with the following code:

from __future__ import print_function, division, absolute_import

# system modules
import os
import sys
import time

# numpy
import numpy as np

# uproot
import uproot

from numba import jit

#@jit
def read(fin, branch='Events'):
    normalBranches = []
    with uproot.open(fin) as istream:
        tree = istream[branch]
        for key, val in tree.allitems():
            data = val.array()
            if not isinstance(data, uproot.interp.jagged.JaggedArray):
                normalBranches.append(key)
        print("number of non-jagged branches %s" % len(normalBranches))
        time0 = time.time()
        for key in normalBranches:
            data = tree[key].array()
        print("elapsed time", time.time()-time0)

read('/opt/cms/data/nano-RelValTTBar.root')

The time I spend reading the non-jagged (I call them normal) branches is quite high: 1.8 s on my Mac. I understand that I could refactor the code to access each array only once. I did it this way on purpose, to measure only the time spent reading the non-jagged arrays: in the first pass I identify the non-jagged branches, and in the second pass I measure the time to read them all.

If I apply the jit decorator (commented out in the code), I get the following error:

Traceback (most recent call last):
  File "./test.py", line 40, in <module>
    read('/opt/cms/data/nano-RelValTTBar.root')
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 763, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 722, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 709, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
AssertionError: Caused By:
Traceback (most recent call last):
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 374, in stage_analyze_bytecode
    func_ir = translate_stage(self.func_id, self.bc)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/compiler.py", line 827, in translate_stage
    return interp.interpret(bytecode)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/interpreter.py", line 92, in interpret
    self.cfa.run()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/numba/controlflow.py", line 515, in run
    assert not inst.is_jump, inst
AssertionError: SETUP_WITH(arg=176, lineno=28)

Failed at object (analyzing bytecode)
SETUP_WITH(arg=176, lineno=28)

I'm new to Numba and am probably doing something stupid, but I got the impression that the code is Numba-aware, which could speed things up.

The main question is whether reading 583 branches in 1.8 s is the expected speed. Can it be improved, and if so, how?

The main use case here is to develop a reader for physics events, so I need to cache both the normal and the jagged branches. As I pointed out in #40, the cache parameter causes an error. I can use my own cache to store the data, but depending on the dimensions of the branches and the number of events, this can cause a large memory overhead. And spending almost 2 s per event to read branches seems like somewhat "poor" performance to me.

@jpivarski
Member

Two seconds per event would be terrible, but there are 9000 events in this tree, so you're actually loading 9000 events in the two-second timespan— the interpretation of the two-second result is off by four orders of magnitude. :)
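The arithmetic behind that correction is short enough to write out. The helper name below is hypothetical; the inputs (1.8 s, 9000 entries) are the figures quoted in this thread.

```python
def per_event_seconds(total_seconds, num_entries):
    """Average wall-clock time per event when all entries are read in one pass."""
    return total_seconds / num_entries

# Figures from this thread: about 1.8 s to read the flat branches of a 9000-entry tree.
per_event = per_event_seconds(1.8, 9000)
print("{:.1f} ms per event".format(per_event * 1000))  # 0.2 ms per event

# Reading the total as a per-event cost overstates it by the number of entries.
overstatement = 1.8 / per_event
print(round(overstatement))  # 9000, roughly four orders of magnitude
```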

I did the same thing and got about 1‒2 seconds on my Chromebook (depending on the run, always with a warm OS page cache). I used the TTree.arrays function to get them all at once.

from __future__ import print_function
import time
import uproot
from uproot.interp.jagged import JaggedArray

tree = uproot.open("~/storage/data/nano-RelValTTBar.root")["Events"]

def performance(tree, allarrays, startTime, endTime):
    # for jagged arrays, count the flattened contents plus the stops offsets
    nbytes = sum(x.contents.nbytes + x.stops.nbytes if isinstance(x, JaggedArray) else x.nbytes
                 for x in allarrays.values())
    print("# {} entries, {} branches, {} MB, {} sec, {} MB/sec, {} kHz".format(
        tree.numentries,
        len(allarrays),
        nbytes / 1024**2,
        endTime - startTime,
        nbytes / 1024**2 / (endTime - startTime),
        tree.numentries / (endTime - startTime) / 1000))

startTime = time.time()
allarrays = tree.arrays()    # <-- read all arrays, right here
endTime = time.time()
performance(tree, allarrays, startTime, endTime)
# 9000 entries, 684 branches, 39 MB, 1.05251812935 sec, 37.0539935727 MB/sec, 8.55092159369 kHz

The rate scales most strongly with number of entries and number of branches, which is natural because their product is the amount of data read. Therefore, the most solid number is number of bytes per unit time, and 37 MB/sec strikes me as reasonable.
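To see why bytes per second is the more stable figure, here is a small sketch of the same MB/sec and kHz formulas used in `performance` above, fed with made-up inputs (hypothetical: 4-byte values, 500 versus 250 flat branches). Halving the branch count while also halving the time leaves MB/sec unchanged but doubles the kHz figure.

```python
MB = 1024 ** 2  # the thread reports sizes in units of 1024**2 bytes

def throughput(nbytes, num_entries, elapsed):
    """Return (MB/sec, kHz) for reading nbytes spanning num_entries in elapsed seconds."""
    mb_per_sec = nbytes / MB / elapsed
    khz = num_entries / elapsed / 1000.0
    return mb_per_sec, khz

# Hypothetical flat reads: 9000 entries x N branches x 4-byte values.
rate_a, khz_a = throughput(9000 * 500 * 4, 9000, 2.0)  # 500 branches in 2 s
rate_b, khz_b = throughput(9000 * 250 * 4, 9000, 1.0)  # 250 branches in 1 s
print(rate_a == rate_b, khz_a, khz_b)  # True 4.5 9.0
```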

Note that we can also make fine-grained selections with the branches argument, picking out just the jagged or just the flat branches before reading anything. Your code reads every branch first and then discards what you don't want. Even within your for-loop structure, you can check whether uproot.interpret(branch) is an asdtype or an asjagged before calling array().

passjagged = lambda x: x if isinstance(x, uproot.interp.asjagged) else None
passflat   = lambda x: x if isinstance(x, uproot.interp.asdtype)  else None

startTime = time.time()
allarrays = tree.arrays(branches=lambda branch: passjagged(uproot.interpret(branch)))
endTime = time.time()
performance(tree, allarrays, startTime, endTime)
# 9000 entries, 121 branches, 33 MB, 0.761940956116 sec, 43.3104425417 MB/sec, 11.811938875 kHz

startTime = time.time()
allarrays = tree.arrays(branches=lambda branch: passflat(uproot.interpret(branch)))
endTime = time.time()
performance(tree, allarrays, startTime, endTime)
# 9000 entries, 563 branches, 5 MB, 0.340945959091 sec, 14.6650806871 MB/sec, 26.3971452367 kHz
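The principle behind the `branches` selection above is to decide from metadata first and only then pay for reading. The sketch below shows that pattern without uproot. `Branch`, its `kind` field, the branch names, and `fake_read` are all hypothetical stand-ins; in uproot the decision would be based on the object returned by `uproot.interpret(branch)`.

```python
from collections import namedtuple

# Hypothetical stand-in for a branch and its interpretation kind ("flat" or "jagged").
Branch = namedtuple("Branch", ["name", "kind"])

def select_names(branches, kind):
    """Decide from metadata alone which branches to read; no data is touched."""
    return [b.name for b in branches if b.kind == kind]

def read_selected(branches, kind, read):
    """Call the expensive read() only for the branches that pass the selection."""
    return {name: read(name) for name in select_names(branches, kind)}

# Illustrative usage: hypothetical branch names and a read() that logs its calls.
branches = [Branch("Jet_pt", "jagged"), Branch("HLT_IsoMu24", "flat"), Branch("nJet", "flat")]
read_log = []

def fake_read(name):
    read_log.append(name)
    return "data for " + name

flat = read_selected(branches, "flat", fake_read)
print(read_log)  # ['HLT_IsoMu24', 'nJet']
```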

That dip when reading only flat branches is interesting. Investigating further, I found that most of them are booleans because they're trigger decisions. Running the tree.arrays function with different arguments through cProfile.run (sorted by tottime), I find that the time is normally dominated by a Numpy memmap.py:334(__getitem__) function (good), but for these branches it is surpassed by the construction of a lock for a ThreadSafeDict in TTree._threadsafe_iterate_keys. This is something for me to think about improving: setup overhead is dominating over the big data stream.
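For anyone who wants to reproduce that kind of measurement, here is a minimal sketch of profiling a call and sorting the report by tottime. The one-liner form is `cProfile.run("tree.arrays()", sort="tottime")`; the sketch instead uses `cProfile.Profile` with `pstats` so that the report can be captured as a string. The two workload functions are hypothetical stand-ins for per-branch setup overhead and a single bulk data operation.

```python
import cProfile
import io
import pstats

def setup_heavy(n):
    # Many small per-item setup steps (stand-in for per-branch overhead).
    return [dict(index=i) for i in range(n)]

def data_heavy(n):
    # One bulk operation (stand-in for the big data stream).
    return sum(range(n))

def workload():
    setup_heavy(20000)
    data_heavy(200000)

def profile_by_tottime(func, limit=5):
    """Profile func() and return the top `limit` entries sorted by tottime."""
    profiler = cProfile.Profile()
    profiler.enable()
    func()
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream).sort_stats("tottime")
    stats.print_stats(limit)
    return stream.getvalue()

print(profile_by_tottime(workload))
```

Sorting by tottime (time spent inside a function itself, excluding callees) is what makes it possible to distinguish setup overhead from the bulk data work.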

Meanwhile, since this is per-key setup overhead, it can be avoided by not loading those keys more than once:

keycache = {}

startTime = time.time()
allarrays = tree.arrays(branches=lambda branch: passflat(uproot.interpret(branch)),
                        keycache=keycache)
endTime = time.time()
performance(tree, allarrays, startTime, endTime)

startTime = time.time()
allarrays = tree.arrays(branches=lambda branch: passflat(uproot.interpret(branch)),
                        keycache=keycache)
endTime = time.time()
performance(tree, allarrays, startTime, endTime)

yields

# 9000 entries, 563 branches, 5 MB, 0.295423030853 sec, 16.9248822123 MB/sec, 30.4647879822 kHz
# 9000 entries, 563 branches, 5 MB, 0.171831846237 sec, 29.0982149671 MB/sec, 52.3767869407 kHz

Yup.
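The speedup comes from passing the same `keycache` dict to both calls, so the second call skips setup work that the first one already did. Here is a sketch of that memoization pattern in plain Python. It is not uproot's internal implementation; `load_key` and `arrays` are hypothetical stand-ins, and a call log replaces timing so that the effect is deterministic.

```python
LOAD_CALLS = []  # records each expensive setup step, for illustration

def load_key(name):
    """Hypothetical stand-in for per-branch setup (e.g. reading basket keys)."""
    LOAD_CALLS.append(name)
    return {"name": name}

def arrays(names, keycache=None):
    """Return one entry per name, reusing setup results stored in keycache."""
    if keycache is None:
        keycache = {}  # no shared cache: setup is redone on every call
    out = {}
    for name in names:
        if name not in keycache:
            keycache[name] = load_key(name)
        out[name] = keycache[name]
    return out

keycache = {}
arrays(["a", "b"], keycache=keycache)
arrays(["a", "b"], keycache=keycache)
print(len(LOAD_CALLS))  # 2: the second call reuses both cached keys
```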

@jpivarski
Member

Second thing, worth its own comment, is that Numba shouldn't be expected to help optimize a function which is essentially the "slow control" directing "hot loops." Numba is for optimizing just the hot loops.

Therefore, its lists of supported Python features and supported Numpy features concentrate on numerical and array operations. The first bulleted list on the Python language features page says what it doesn't support, and the third bullet is the with statement. Considering the simple nature of that statement, I'd expect it to be supported in the future, but it's not available now. The more dynamic features of Python, such as changing the class of an object after it has been created, will probably never be supported.
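A practical fix is to keep the `with` statement (and all file handling) in ordinary Python and jit only the numeric loop. The sketch below illustrates that split. `io.StringIO` stands in for the file opened by `uproot.open(...)`, and a plain list stands in for the NumPy array that `.array()` would return. `count_above` is a hypothetical kernel, and the fallback decorator (used when Numba is not installed) is an addition for portability.

```python
import io

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is not installed
    def njit(func):
        return func

@njit
def count_above(values, threshold):
    """Hot loop: only numeric work, no with statement, so Numba can compile it."""
    count = 0
    for i in range(len(values)):
        if values[i] > threshold:
            count += 1
    return count

def read_values(stream):
    """Slow control: the with statement stays in ordinary Python."""
    with stream:
        return [float(line) for line in stream]

values = read_values(io.StringIO("1.5\n20.0\n42.0\n3.0\n"))
print(count_above(values, 10.0))  # 2
```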

My biggest gripe with Numba is its verbose error messages: it took 35 lines and 2600 characters to tell you that it doesn't support the with statement, and then it did so with this cryptic message:

AssertionError: SETUP_WITH(arg=176, lineno=28)

Failed at object (analyzing bytecode)
SETUP_WITH(arg=176, lineno=28)

I've been thinking about contributing some sort of "data analyst-friendly error message mode" to Numba when I get out from under a pile of work.

On your part as a data analyst, it is useful to think about the steps in analysis code in terms of slow control and hot loops, since Numpy, uproot, and most tools in the scientific Python ecosystem have been designed around this philosophy. Similarly, the longtime R users I've met have the ingrained rule "never write a for loop," specifically never over the large dataset. For us, the order parameter is clear: the number of events is usually the only large parameter in the system, so anything that loops over events has to be in a compiled function. Anything that doesn't loop over events doesn't need to be, and for flexibility probably shouldn't be.

That's why the val.array() call gives you all events— it makes a single Numpy call on each ROOT basket, each of which should contain a large number of events. This is the Numpy call that tops the cProfile.run measurement, sorted by "tottime". Once you've done that, there's no reason to compile the surrounding code because it's small potatoes and usually too complex to compile effectively, anyway.
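To make the "one call per basket" point concrete, the sketch below counts how many bulk calls are made when 9000 events are processed basket by basket. The basket size of 1000 is a hypothetical choice (real ROOT basket sizes vary), and the builtin `sum` stands in for the per-basket NumPy call. The Python-level overhead scales with the number of baskets, not the number of events.

```python
def iterate_baskets(values, basket_size):
    """Yield consecutive chunks, a hypothetical stand-in for ROOT baskets."""
    for start in range(0, len(values), basket_size):
        yield values[start:start + basket_size]

def bulk_sum(basket, calls):
    """One bulk call per basket; `calls` records how many were made."""
    calls.append(len(basket))
    return sum(basket)

def total(values, basket_size):
    calls = []
    result = sum(bulk_sum(b, calls) for b in iterate_baskets(values, basket_size))
    return result, len(calls)

events = list(range(9000))
result, ncalls = total(events, 1000)
print(result, ncalls)  # 40495500 9
```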
