Skip to content

Commit

Permalink
groot/rnpy: new package to ease ROOT-Tree/NumPy conversion
Browse files Browse the repository at this point in the history
Fixes #903.
  • Loading branch information
sbinet committed Feb 8, 2022
1 parent 37a9ac0 commit be7ddfa
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 102 deletions.
9 changes: 5 additions & 4 deletions cmd/npy2root/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ import (
"os"

"github.com/apache/arrow/go/arrow/arrio"
"github.com/sbinet/npyio"
"github.com/sbinet/npyio/npy"
"go-hep.org/x/hep/groot"
"go-hep.org/x/hep/groot/rarrow"
"go-hep.org/x/hep/groot/rnpy"
"go-hep.org/x/hep/groot/rtree"
)

Expand Down Expand Up @@ -114,12 +115,12 @@ func process(oname, tname, fname string) error {
}
defer src.Close()

npy, err := npyio.NewReader(src)
npy, err := npy.NewReader(src)
if err != nil {
return fmt.Errorf("could not create numpy file reader %q: %w", fname, err)
}

rec := NewRecord(npy)
rec := rnpy.NewRecord(npy)
defer rec.Release()

dst, err := groot.Create(oname)
Expand All @@ -133,7 +134,7 @@ func process(oname, tname, fname string) error {
return fmt.Errorf("could not create output ROOT tree %q: %w", tname, err)
}

_, err = arrio.Copy(t, NewRecordReader(rec))
_, err = arrio.Copy(t, rnpy.NewRecordReader(rec))
if err != nil {
return fmt.Errorf("could not copy numpy array to ROOT tree %q: %w", tname, err)
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/npy2root/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"reflect"
"testing"

"github.com/sbinet/npyio"
"github.com/sbinet/npyio/npy"
"go-hep.org/x/hep/groot"
"go-hep.org/x/hep/groot/rtree"
)
Expand Down Expand Up @@ -251,7 +251,7 @@ func TestConvert(t *testing.T) {
}
defer src.Close()

err = npyio.Write(src, tc.want)
err = npy.Write(src, tc.want)
if err != nil {
t.Fatalf("could not save NumPy data file: %+v", err)
}
Expand Down
108 changes: 18 additions & 90 deletions cmd/root2npy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,14 @@ import (
"io"
"log"
"os"
"reflect"

"github.com/sbinet/npyio"

"go-hep.org/x/hep/groot"
"go-hep.org/x/hep/groot/riofs"
_ "go-hep.org/x/hep/groot/riofs/plugin/http"
_ "go-hep.org/x/hep/groot/riofs/plugin/xrootd"
"go-hep.org/x/hep/groot/rnpy"
"go-hep.org/x/hep/groot/rtree"
)

Expand Down Expand Up @@ -181,21 +181,7 @@ func process(oname, fname, tname string) error {
return fmt.Errorf("object %q in file %q is not a rtree.Tree", tname, fname)
}

var (
nt = ntuple{n: tree.Entries()}
rvars = rtree.NewReadVars(tree)
)
log.Printf("scanning leaves...")
for _, rvar := range rvars {
rv := reflect.ValueOf(rvar.Value).Elem()
switch rv.Kind() {
case reflect.Struct, reflect.Slice:
log.Printf(">>> %q %T not supported", rvar.Name, rv.Interface())
continue
}
nt.add(rvar)
}
log.Printf("scanning leaves... [done]")
cols := rnpy.NewColumns(tree)

out, err := os.Create(oname)
if err != nil {
Expand All @@ -208,22 +194,31 @@ func process(oname, fname, tname string) error {

wrk := make([]byte, 1*1024*1024)
buf := new(bytes.Buffer)
for i := range nt.cols {
col := &nt.cols[i]
for _, col := range cols {
buf.Reset()
err := col.process(buf, tree)

sli, err := col.Slice()
if err != nil {
return fmt.Errorf("could not read %q: %w", col.Name(), err)
}

err = npyio.Write(buf, sli)
if err != nil {
return fmt.Errorf("could not write %q: %w", col.Name(), err)
}

if err != nil {
return fmt.Errorf("could not process column %q: %w", col.rvar.Name, err)
return fmt.Errorf("could not process column %q: %w", col.Name(), err)
}

wz, err := npz.Create(col.rvar.Name)
wz, err := npz.Create(col.Name())
if err != nil {
return fmt.Errorf("could not create column %q: %w", col.rvar.Name, err)
return fmt.Errorf("could not create column %q: %w", col.Name(), err)
}

_, err = io.CopyBuffer(wz, buf, wrk)
if err != nil {
return fmt.Errorf("could not save column %q: %w", col.rvar.Name, err)
return fmt.Errorf("could not save column %q: %w", col.Name(), err)
}
}

Expand All @@ -244,70 +239,3 @@ func process(oname, fname, tname string) error {

return nil
}

// ntuple gathers the columns to be converted, together with the entry
// count that every column is sized with.
type ntuple struct {
// n is the number of entries; add forwards it to newColumn as the
// length of each column.
n int64
// cols holds the columns registered through add, in insertion order.
cols []column
}

// add registers a new column bound to rvar and sized to hold nt.n entries.
func (nt *ntuple) add(rvar rtree.ReadVar) {
	col := newColumn(rvar, nt.n)
	nt.cols = append(nt.cols, col)
}

// column holds the state needed to read one ROOT read-variable into a Go
// slice and write that slice out as a NumPy array.
type column struct {
// rvar is the read-variable the column is bound to.
rvar rtree.ReadVar
// i is the index in slice where the next entry read by process is stored.
i int64
// etype is the element type, i.e. the type rvar.Value points to.
etype reflect.Type
// shape is the one-dimensional shape {n} of the column.
shape []int
// data is the value rvar.Value points to; process copies it into slice
// once per entry.
data reflect.Value
// slice is the destination slice of etype elements.
slice reflect.Value
}

// newColumn creates a column bound to rvar with room for n entries.
//
// The element type is the type rvar.Value points to, and the destination
// slice is allocated up front with both length and capacity equal to n.
func newColumn(rvar rtree.ReadVar, n int64) column {
	var (
		nelem = int(n)
		elem  = reflect.TypeOf(rvar.Value).Elem()
	)

	col := column{
		rvar:  rvar,
		etype: elem,
		shape: []int{nelem},
	}

	styp := reflect.SliceOf(elem)
	col.data = reflect.ValueOf(rvar.Value).Elem()
	col.slice = reflect.MakeSlice(styp, nelem, nelem)

	return col
}

// process reads every entry of the column's read-variable from t, stores
// the values into the column's slice and writes that slice to w in NumPy
// format.
//
// The column is reset when process returns, whether it succeeded or not.
func (col *column) process(w io.Writer, t rtree.Tree) error {
	defer col.reset()

	rvars := []rtree.ReadVar{col.rvar}
	r, err := rtree.NewReader(t, rvars)
	if err != nil {
		return fmt.Errorf(
			"could not create ROOT reader for %q: %w",
			col.rvar.Name, err,
		)
	}
	defer r.Close()

	// store copies the value just read into the next free slot of the slice.
	store := func(ctx rtree.RCtx) error {
		idx := int(col.i)
		col.slice.Index(idx).Set(col.data)
		col.i++
		return nil
	}

	if err := r.Read(store); err != nil {
		return fmt.Errorf("could not read ROOT data for %q: %w", col.rvar.Name, err)
	}

	if err := npyio.Write(w, col.slice.Interface()); err != nil {
		return fmt.Errorf("could not write %q: %w", col.rvar.Name, err)
	}

	return nil
}

// reset rewinds the write index to zero and replaces both the slice and
// the data values with zero values of their respective types.
//
// It is deferred by process, so the buffered column content is dropped
// once the column has been handled.
func (col *column) reset() {
col.i = 0
col.slice = reflect.Zero(col.slice.Type())
col.data = reflect.Zero(col.data.Type())
}
11 changes: 6 additions & 5 deletions cmd/npy2root/arrow.go → groot/rnpy/arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main
package rnpy

import (
"fmt"
Expand All @@ -14,7 +14,7 @@ import (
"github.com/apache/arrow/go/arrow/array"
"github.com/apache/arrow/go/arrow/arrio"
"github.com/apache/arrow/go/arrow/memory"
"github.com/sbinet/npyio"
"github.com/sbinet/npyio/npy"
)

var (
Expand Down Expand Up @@ -47,7 +47,8 @@ type Record struct {
cols []array.Interface
}

func NewRecord(npy *npyio.Reader) *Record {
// NewRecord returns an Arrow Record from a NumPy data file reader.
func NewRecord(npy *npy.Reader) *Record {
var (
mem = memory.NewGoAllocator()
schema = schemaFrom(npy)
Expand Down Expand Up @@ -125,7 +126,7 @@ func (rec *Record) NewSlice(i, j int64) array.Record {
panic("not implemented")
}

func (rec *Record) read(r *npyio.Reader, nelem int64, bldr array.Builder) {
func (rec *Record) read(r *npy.Reader, nelem int64, bldr array.Builder) {
rt := dtypeFrom(rec.schema.Field(0).Type)
rv := reflect.New(reflect.SliceOf(rt)).Elem()
rv.Set(reflect.MakeSlice(rv.Type(), int(nelem), int(nelem)))
Expand All @@ -150,7 +151,7 @@ func (rec *Record) read(r *npyio.Reader, nelem int64, bldr array.Builder) {
rec.cols = append(rec.cols, bldr.NewArray())
}

func schemaFrom(npy *npyio.Reader) *arrow.Schema {
func schemaFrom(npy *npy.Reader) *arrow.Schema {
var (
hdr = npy.Header
dtype arrow.DataType
Expand Down
2 changes: 1 addition & 1 deletion cmd/npy2root/arrow_test.go → groot/rnpy/arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main
package rnpy

import (
"os"
Expand Down
130 changes: 130 additions & 0 deletions groot/rnpy/column.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright ©2022 The go-hep Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package rnpy

import (
"fmt"
"reflect"

"go-hep.org/x/hep/groot/rtree"
)

// NewColumns builds one Column for each ReadVar exposed by the provided Tree.
//
// A ReadVar whose pointed-to value is a channel, interface, struct, slice,
// map, pointer or unsafe pointer can not be represented as a NumPy array:
// it is dropped without any error being reported.
func NewColumns(tree rtree.Tree) []Column {
	var out []Column

	for _, rvar := range rtree.NewReadVars(tree) {
		kind := reflect.ValueOf(rvar.Value).Elem().Kind()
		unsupported := kind == reflect.Chan ||
			kind == reflect.Interface ||
			kind == reflect.Struct ||
			kind == reflect.Slice ||
			kind == reflect.Map ||
			kind == reflect.Ptr ||
			kind == reflect.UnsafePointer
		if unsupported {
			continue
		}

		out = append(out, Column{
			tree: tree,
			rvar: rvar,
			etyp: reflect.TypeOf(rvar.Value).Elem(),
		})
	}

	return out
}

// Column provides a NumPy representation of a Branch or Leaf.
//
// A Column keeps a reference to the Tree it reads from, and does not hold
// any entry data itself: Slice reads the data from the Tree on each call.
type Column struct {
// tree is the ROOT Tree the column data is read from.
tree rtree.Tree
// rvar is the read-variable selecting what is read from tree.
rvar rtree.ReadVar
// etyp is the element type, i.e. the type rvar.Value points to.
etyp reflect.Type
}

// NewColumn creates the Column of the provided tree that matches rvar.
//
// The ReadVars of the tree are searched, in order, for the first one with
// the same Name as rvar. When rvar.Leaf is not empty, the Leaf has to match
// as well. The ReadVar found in the tree, not the one passed in, is the one
// the returned Column is bound to.
//
// NewColumn returns an error if no branch or leaf could be found.
// NewColumn returns an error if the branch or leaf is of an unsupported type
// (channel, interface, struct, slice, map, pointer or unsafe pointer).
func NewColumn(tree rtree.Tree, rvar rtree.ReadVar) (Column, error) {
	found := false
	for _, cand := range rtree.NewReadVars(tree) {
		if cand.Name != rvar.Name {
			continue
		}
		// An empty leaf name in the request matches any leaf of the branch.
		if rvar.Leaf != "" && cand.Leaf != rvar.Leaf {
			continue
		}
		rvar, found = cand, true
		break
	}

	if !found {
		name := rvar.Name
		if rvar.Leaf != "" {
			name = name + "." + rvar.Leaf
		}
		return Column{}, fmt.Errorf("rnpy: no rvar named %q", name)
	}

	val := reflect.ValueOf(rvar.Value).Elem()
	switch val.Kind() {
	case reflect.Chan,
		reflect.Interface,
		reflect.Struct,
		reflect.Slice,
		reflect.Map,
		reflect.Ptr,
		reflect.UnsafePointer:
		return Column{}, fmt.Errorf("rnpy: invalid branch or leaf type %T", val.Interface())
	}

	return Column{
		tree: tree,
		rvar: rvar,
		etyp: reflect.TypeOf(rvar.Value).Elem(),
	}, nil
}

// Name returns the branch name this Column is bound to.
//
// Only the Name of the underlying read-variable is returned: its Leaf
// is not part of the result.
func (col Column) Name() string {
return col.rvar.Name
}

// Slice loads the whole content of the column from the underlying ROOT Tree
// into memory.
//
// The returned value is a newly allocated Go slice whose element type is
// the column's element type and whose length is the number of entries
// reported by the Tree. A new reader is created, and closed, on each call.
func (col Column) Slice() (sli interface{}, err error) {
	reader, err := rtree.NewReader(col.tree, []rtree.ReadVar{col.rvar})
	if err != nil {
		return nil, fmt.Errorf(
			"rnpy: could not create ROOT reader for %q: %w",
			col.rvar.Name, err,
		)
	}
	defer reader.Close()

	nentries := int(col.tree.Entries())
	styp := reflect.SliceOf(col.etyp)
	// src is the value the reader fills in for each entry.
	src := reflect.ValueOf(col.rvar.Value).Elem()
	out := reflect.MakeSlice(styp, nentries, nentries)

	next := 0
	// fill copies the entry just read into the next free slot of out.
	fill := func(ctx rtree.RCtx) error {
		out.Index(next).Set(src)
		next++
		return nil
	}

	if err = reader.Read(fill); err != nil {
		return nil, fmt.Errorf(
			"rnpy: could not read ROOT data for %q: %w",
			col.rvar.Name, err,
		)
	}

	return out.Interface(), nil
}
Loading

0 comments on commit be7ddfa

Please sign in to comment.