From aa464ae254341e281ff2a6496d186a7ff6c2649f Mon Sep 17 00:00:00 2001
From: Marsel Mavletkulov
Date: Mon, 29 Jul 2024 14:12:22 -0400
Subject: [PATCH 1/2] Save 48 bytes per Node

---
 parse.go | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/parse.go b/parse.go
index 1135336..8926063 100644
--- a/parse.go
+++ b/parse.go
@@ -62,19 +62,19 @@ type Tree struct {
 
 // Node is a node in the Tree.
 type Node struct {
-	Data  NodeData
 	Left  *Node
 	Right *Node
+	Data  NodeData
 }
 
 // NodeData is a Node's data.
 type NodeData struct {
-	BaseWeight     float32
-	DefaultLeft    bool
 	ID             int
-	SplitCondition float32
 	SplitIndex     int
+	SplitCondition float32
 	SumHessian     float32
+	BaseWeight     float32
+	DefaultLeft    bool
 }
 
 // IsLeaf returns whether the Node is a leaf.

From 7d5765edc0d81c93b0aa666c7a347df99fff16f2 Mon Sep 17 00:00:00 2001
From: Marsel Mavletkulov
Date: Mon, 29 Jul 2024 15:34:19 -0400
Subject: [PATCH 2/2] Reduce parseModel() allocs by 29.8%

---
 contributions.go |  2 +-
 parse.go         | 12 ++++--------
 parse_test.go    | 14 ++++++++++++++
 3 files changed, 19 insertions(+), 9 deletions(-)
 create mode 100644 parse_test.go

diff --git a/contributions.go b/contributions.go
index ff86c01..b7010a4 100644
--- a/contributions.go
+++ b/contributions.go
@@ -280,7 +280,7 @@ func treeShap(
 	isMissing := features[splitIndex] == nil // nil means missing.
 	hotIndex := getNextNode(
 		hasMissing,
-		node,
+		&node,
 		nodeIndex,
 		features[splitIndex],
 		isMissing,
diff --git a/parse.go b/parse.go
index 8926063..626065f 100644
--- a/parse.go
+++ b/parse.go
@@ -56,7 +56,7 @@ type TreeParam struct {
 // Tree is one tree in an XGBoost model. It's the representation we process
 // XGBTree into.
 type Tree struct {
-	Nodes    []*Node // Index 0 is the root.
+	Nodes    []Node // Index 0 is the root.
 	NumNodes int
 }
 
@@ -136,11 +136,7 @@ func parseTree(
 		return nil, fmt.Errorf("getting num nodes as int64: %w", err)
 	}
 
-	var nodes []*Node
-	for i := 0; i < int(numNodes); i++ {
-		nodes = append(nodes, &Node{})
-	}
-
+	nodes := make([]Node, numNodes)
 	for i := 0; i < int(numNodes); i++ {
 		nodes[i].Data = NodeData{
 			BaseWeight:     xt.BaseWeights[i],
@@ -158,8 +154,8 @@ func parseTree(
 			continue
 		}
 
-		nodes[i].Left = nodes[left]
-		nodes[i].Right = nodes[right]
+		nodes[i].Left = &nodes[left]
+		nodes[i].Right = &nodes[right]
 	}
 
 	return &Tree{
diff --git a/parse_test.go b/parse_test.go
new file mode 100644
index 0000000..69f2ae8
--- /dev/null
+++ b/parse_test.go
@@ -0,0 +1,14 @@
+package xgbshap
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func BenchmarkParseModel(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		_, _, err := parseModel("testdata/small-model/model.json")
+		require.NoError(b, err)
+	}
+}