Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[10_6_X] ONNXRuntime-based implementation of DeepJet, DeepAK8 and DeepDoubleX #30123

Merged
merged 4 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions PhysicsTools/NanoAOD/python/nano_cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,6 @@ def nanoAOD_addDeepInfo(process,addDeepBTag,addDeepFlavour):
process.load("Configuration.StandardSequences.MagneticField_cff")
process.jetCorrFactorsNano.src="selectedUpdatedPatJetsWithDeepInfo"
process.updatedJets.jetSource="selectedUpdatedPatJetsWithDeepInfo"
if addDeepFlavour:
process.pfDeepFlavourJetTagsWithDeepInfo.graph_path = 'RecoBTag/Combined/data/DeepFlavourV03_10X_training/constant_graph.pb'
process.pfDeepFlavourJetTagsWithDeepInfo.lp_names = ["cpf_input_batchnorm/keras_learning_phase"]
return process

from PhysicsTools.PatUtils.tools.runMETCorrectionsAndUncertainties import runMetCorAndUncFromMiniAOD
Expand Down Expand Up @@ -237,7 +234,7 @@ def nanoAOD_addDeepInfoAK8(process,addDeepBTag,addDeepBoostedJet, addDeepDoubleX
_btagDiscriminators += ['pfDeepCSVJetTags:probb','pfDeepCSVJetTags:probbb']
if addDeepBoostedJet:
print("Updating process to run DeepBoostedJet on datasets before 103X")
from RecoBTag.MXNet.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsAll as pfDeepBoostedJetTagsAll
from RecoBTag.ONNXRuntime.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsAll as pfDeepBoostedJetTagsAll
_btagDiscriminators += pfDeepBoostedJetTagsAll
if addDeepDoubleX:
print("Updating process to run DeepDoubleX on datasets before 104X")
Expand Down
5 changes: 5 additions & 0 deletions PhysicsTools/ONNXRuntime/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<use name="onnxruntime"/>
<use name="FWCore/Utilities"/>
<export>
<lib name="1"/>
</export>
66 changes: 66 additions & 0 deletions PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* ONNXRuntime.h
*
* A convenience wrapper of the ONNXRuntime C++ API.
* Based on https://github.com/microsoft/onnxruntime/blob/master/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/CXX_Api_Sample.cpp.
*
* Created on: Jun 28, 2019
* Author: hqu
*/

#ifndef PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_
#define PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_

#include <vector>
#include <map>
#include <string>
#include <memory>

#include "onnxruntime/core/session/onnxruntime_cxx_api.h"

namespace cms::Ort {

typedef std::vector<std::vector<float>> FloatArrays;

class ONNXRuntime {
public:
ONNXRuntime(const std::string& model_path, const ::Ort::SessionOptions* session_options = nullptr);
ONNXRuntime(const ONNXRuntime&) = delete;
ONNXRuntime& operator=(const ONNXRuntime&) = delete;
~ONNXRuntime();

// Run inference and get outputs
// input_names: list of the names of the input nodes.
// input_values: list of input arrays for each input node. The order of `input_values` must match `input_names`.
// output_names: names of the output nodes to get outputs from. Empty list means all output nodes.
// batch_size: number of samples in the batch. Each array in `input_values` must have a shape layout of (batch_size, ...).
// Returns: a std::vector<std::vector<float>>, with the order matched to `output_names`.
// When `output_names` is empty, will return all outputs ordered as in `getOutputNames()`.
FloatArrays run(const std::vector<std::string>& input_names,
FloatArrays& input_values,
const std::vector<std::string>& output_names = {},
int64_t batch_size = 1) const;

// Get a list of names of all the output nodes
const std::vector<std::string>& getOutputNames() const;

// Get the shape of a output node
// The 0th dim depends on the batch size, therefore is set to -1
const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;

private:
static const ::Ort::Env env_;
std::unique_ptr<::Ort::Session> session_;

std::vector<std::string> input_node_strings_;
std::vector<const char*> input_node_names_;
std::map<std::string, std::vector<int64_t>> input_node_dims_;

std::vector<std::string> output_node_strings_;
std::vector<const char*> output_node_names_;
std::map<std::string, std::vector<int64_t>> output_node_dims_;
};

} // namespace cms::Ort

#endif /* PHYSICSTOOLS_ONNXRUNTIME_INTERFACE_ONNXRUNTIME_H_ */
164 changes: 164 additions & 0 deletions PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* ONNXRuntime.cc
*
* Created on: Jun 28, 2019
* Author: hqu
*/

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"

#include <cassert>
#include <iostream>
#include <algorithm>
#include <numeric>
#include <functional>
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/thread_safety_macros.h"

namespace cms::Ort {

using namespace ::Ort;

const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_ERROR, "");

ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
// create session
if (session_options) {
session_.reset(new Session(env_, model_path.c_str(), *session_options));
} else {
SessionOptions sess_opts;
sess_opts.SetIntraOpNumThreads(1);
session_.reset(new Session(env_, model_path.c_str(), sess_opts));
}
AllocatorWithDefaultOptions allocator;

// get input names and shapes
size_t num_input_nodes = session_->GetInputCount();
input_node_strings_.resize(num_input_nodes);
input_node_names_.resize(num_input_nodes);
input_node_dims_.clear();

for (size_t i = 0; i < num_input_nodes; i++) {
// get input node names
std::string input_name(session_->GetInputName(i, allocator));
input_node_strings_[i] = input_name;
input_node_names_[i] = input_node_strings_[i].c_str();

// get input shapes
auto type_info = session_->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
size_t num_dims = tensor_info.GetDimensionsCount();
input_node_dims_[input_name].resize(num_dims);
tensor_info.GetDimensions(input_node_dims_[input_name].data(), num_dims);

// set the batch size to 1 by default
input_node_dims_[input_name].at(0) = 1;
}

size_t num_output_nodes = session_->GetOutputCount();
output_node_strings_.resize(num_output_nodes);
output_node_names_.resize(num_output_nodes);
output_node_dims_.clear();

for (size_t i = 0; i < num_output_nodes; i++) {
// get output node names
std::string output_name(session_->GetOutputName(i, allocator));
output_node_strings_[i] = output_name;
output_node_names_[i] = output_node_strings_[i].c_str();

// get output node types
auto type_info = session_->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
size_t num_dims = tensor_info.GetDimensionsCount();
output_node_dims_[output_name].resize(num_dims);
tensor_info.GetDimensions(output_node_dims_[output_name].data(), num_dims);

// the 0th dim depends on the batch size
output_node_dims_[output_name].at(0) = -1;
}
}

ONNXRuntime::~ONNXRuntime() {}

FloatArrays ONNXRuntime::run(const std::vector<std::string>& input_names,
FloatArrays& input_values,
const std::vector<std::string>& output_names,
int64_t batch_size) const {
assert(input_names.size() == input_values.size());
assert(batch_size > 0);

// create input tensor objects from data values
std::vector<Value> input_tensors;
auto memory_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
for (const auto& name : input_node_strings_) {
auto iter = std::find(input_names.begin(), input_names.end(), name);
if (iter == input_names.end()) {
throw cms::Exception("RuntimeError") << "Input " << name << " is not provided!";
}
auto value = input_values.begin() + (iter - input_names.begin());
auto input_dims = input_node_dims_.at(name);
input_dims[0] = batch_size;
auto expected_len = std::accumulate(input_dims.begin(), input_dims.end(), 1, std::multiplies<int64_t>());
if (expected_len != (int64_t)value->size()) {
throw cms::Exception("RuntimeError")
<< "Input array " << name << " has a wrong size of " << value->size() << ", expected " << expected_len;
}
auto input_tensor =
Value::CreateTensor<float>(memory_info, value->data(), value->size(), input_dims.data(), input_dims.size());
assert(input_tensor.IsTensor());
input_tensors.emplace_back(std::move(input_tensor));
}

// set output node names; will get all outputs if `output_names` is not provided
std::vector<const char*> run_output_node_names;
if (output_names.empty()) {
run_output_node_names = output_node_names_;
} else {
for (const auto& name : output_names) {
run_output_node_names.push_back(name.c_str());
}
}

// run
auto output_tensors = session_->Run(RunOptions{nullptr},
input_node_names_.data(),
input_tensors.data(),
input_tensors.size(),
run_output_node_names.data(),
run_output_node_names.size());

// convert output to floats
FloatArrays outputs;
for (auto& output_tensor : output_tensors) {
assert(output_tensor.IsTensor());

// get output shape
auto tensor_info = output_tensor.GetTensorTypeAndShapeInfo();
auto length = tensor_info.GetElementCount();

auto floatarr = output_tensor.GetTensorMutableData<float>();
outputs.emplace_back(floatarr, floatarr + length);
}
assert(outputs.size() == run_output_node_names.size());

return outputs;
}

const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
if (session_) {
return output_node_strings_;
} else {
throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
}
}

const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
auto iter = output_node_dims_.find(output_name);
if (iter == output_node_dims_.end()) {
throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
} else {
return iter->second;
}
}

} /* namespace cms::Ort */
7 changes: 7 additions & 0 deletions PhysicsTools/ONNXRuntime/test/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<bin name="testONNXRuntime" file="testRunner.cpp, testONNXRuntime.cc">
<use name="boost_filesystem"/>
<use name="cppunit"/>
<use name="PhysicsTools/ONNXRuntime"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/Utilities"/>
</bin>
Binary file added PhysicsTools/ONNXRuntime/test/data/model.onnx
Binary file not shown.
37 changes: 37 additions & 0 deletions PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <cppunit/extensions/HelperMacros.h>

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"

#include <chrono>
#include <iostream>

using namespace cms::Ort;

class testONNXRuntime : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(testONNXRuntime);
CPPUNIT_TEST(checkAll);
CPPUNIT_TEST_SUITE_END();

public:
void checkAll();
};

CPPUNIT_TEST_SUITE_REGISTRATION(testONNXRuntime);

void testONNXRuntime::checkAll() {
std::string model_path = edm::FileInPath("PhysicsTools/ONNXRuntime/test/data/model.onnx").fullPath();
ONNXRuntime rt(model_path);
for (const unsigned batch_size : {1, 2, 4}) {
FloatArrays input_values{
std::vector<float>(batch_size * 2, 1),
};
FloatArrays outputs;
CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {"Y"}, batch_size));
CPPUNIT_ASSERT(outputs.size() == 1);
CPPUNIT_ASSERT(outputs[0].size() == batch_size);
for (const auto &v : outputs[0]) {
CPPUNIT_ASSERT(v == 3);
}
}
}
1 change: 1 addition & 0 deletions PhysicsTools/ONNXRuntime/test/testRunner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#include <Utilities/Testing/interface/CppUnit_testdriver.icpp>
2 changes: 1 addition & 1 deletion PhysicsTools/PatAlgos/python/recoLayer0/bTagging_cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@

# -----------------------------------
# setup DeepBoostedJet
from RecoBTag.MXNet.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsProbs, _pfDeepBoostedJetTagsMetaDiscrs, \
from RecoBTag.ONNXRuntime.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsProbs, _pfDeepBoostedJetTagsMetaDiscrs, \
_pfMassDecorrelatedDeepBoostedJetTagsProbs, _pfMassDecorrelatedDeepBoostedJetTagsMetaDiscrs
# update supportedBtagDiscr
for disc in _pfDeepBoostedJetTagsProbs + _pfMassDecorrelatedDeepBoostedJetTagsProbs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def applyDeepBtagging( process, postfix="" ) :
# delete module not used anymore (slimmedJets substitutes)
delattr(process, 'selectedUpdatedPatJetsSlimmedDeepFlavour'+postfix)

from RecoBTag.MXNet.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsAll as pfDeepBoostedJetTagsAll
from RecoBTag.ONNXRuntime.pfDeepBoostedJet_cff import _pfDeepBoostedJetTagsAll as pfDeepBoostedJetTagsAll
from RecoBTag.MXNet.pfParticleNet_cff import _pfParticleNetJetTagsAll as pfParticleNetJetTagsAll

# update slimmed jets to include particle-based deep taggers (keep same name)
Expand Down
6 changes: 3 additions & 3 deletions RecoBTag/Configuration/python/RecoBTag_cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from RecoBTag.Combined.combinedMVA_cff import *
from RecoBTag.CTagging.RecoCTagging_cff import *
from RecoBTag.Combined.deepFlavour_cff import *
from RecoBTag.TensorFlow.pfDeepFlavour_cff import *
from RecoBTag.TensorFlow.pfDeepDoubleX_cff import *
from RecoBTag.MXNet.pfDeepBoostedJet_cff import *
from RecoBTag.ONNXRuntime.pfDeepFlavour_cff import *
from RecoBTag.ONNXRuntime.pfDeepDoubleX_cff import *
from RecoBTag.ONNXRuntime.pfDeepBoostedJet_cff import *
from RecoBTag.MXNet.pfParticleNet_cff import *
from RecoVertex.AdaptiveVertexFinder.inclusiveVertexing_cff import *

Expand Down
39 changes: 13 additions & 26 deletions RecoBTag/MXNet/plugins/BoostedJetMXNetJetTagsProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,36 +141,23 @@ BoostedJetMXNetJetTagsProducer::BoostedJetMXNetJetTagsProducer(const edm::Parame
BoostedJetMXNetJetTagsProducer::~BoostedJetMXNetJetTagsProducer() {}

void BoostedJetMXNetJetTagsProducer::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
// pfDeepBoostedJetTags
// pfParticleNetJetTags
edm::ParameterSetDescription desc;
desc.add<edm::InputTag>("src", edm::InputTag("pfDeepBoostedJetTagInfos"));
desc.add<edm::InputTag>("src", edm::InputTag("pfParticleNetTagInfos"));
edm::ParameterSetDescription preprocessParams;
preprocessParams.setAllowAnything();
desc.add<edm::ParameterSetDescription>("preprocessParams", preprocessParams);
desc.add<edm::FileInPath>("model_path",
edm::FileInPath("RecoBTag/Combined/data/DeepBoostedJet/V01/full/resnet-symbol.json"));
desc.add<edm::FileInPath>("param_path",
edm::FileInPath("RecoBTag/Combined/data/DeepBoostedJet/V01/full/resnet-0000.params"));
desc.add<std::vector<std::string>>("flav_names",
std::vector<std::string>{
"probTbcq",
"probTbqq",
"probTbc",
"probTbq",
"probWcq",
"probWqq",
"probZbb",
"probZcc",
"probZqq",
"probHbb",
"probHcc",
"probHqqqq",
"probQCDbb",
"probQCDcc",
"probQCDb",
"probQCDc",
"probQCDothers",
});
desc.add<edm::FileInPath>(
"model_path", edm::FileInPath("RecoBTag/Combined/data/ParticleNetAK8/General/V00/ParticleNet-symbol.json"));
desc.add<edm::FileInPath>(
"param_path", edm::FileInPath("RecoBTag/Combined/data/ParticleNetAK8/General/V00/ParticleNet-0000.params"));
desc.add<std::vector<std::string>>(
"flav_names",
std::vector<std::string>{
"probTbcq", "probTbqq", "probTbc", "probTbq", "probTbel", "probTbmu", "probTbta",
"probWcq", "probWqq", "probZbb", "probZcc", "probZqq", "probHbb", "probHcc",
"probHqqqq", "probQCDbb", "probQCDcc", "probQCDb", "probQCDc", "probQCDothers",
});
desc.addOptionalUntracked<bool>("debugMode", false);

descriptions.addWithDefaultLabel(desc);
Expand Down
1 change: 0 additions & 1 deletion RecoBTag/MXNet/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
<use name="FWCore/Framework"/>
<use name="FWCore/MessageLogger"/>
<library file="*.cc" name="RecoBTagMXNetPlugins">
<use name="DataFormats/BTauReco"/>
<use name="PhysicsTools/MXNet"/>
Expand Down
Loading