Skip to content

Commit

Permalink
Restore TF and MXNet-based inference.
Browse files Browse the repository at this point in the history
And enable ONNXRuntime for x86/arm only.
  • Loading branch information
hqucms committed Mar 10, 2020
1 parent c1ef845 commit d8951bb
Show file tree
Hide file tree
Showing 28 changed files with 971 additions and 520 deletions.
13 changes: 13 additions & 0 deletions PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@
#include <string>
#include <memory>

// currently ONNXRUNTIME only supports x86 and ARM
#if defined(__arm__) || defined(__aarch64__) || defined(__x86_64__) || defined(__i386__)
#define CMS_USE_ONNXRUNTIME
#endif

#ifdef CMS_USE_ONNXRUNTIME
#include "onnxruntime/core/session/onnxruntime_cxx_api.h"
#else
namespace Ort {
struct SessionOptions {};
} // namespace Ort
#endif

namespace cms::Ort {

Expand Down Expand Up @@ -48,6 +59,7 @@ namespace cms::Ort {
// The 0th dim depends on the batch size, therefore is set to -1
const std::vector<int64_t>& getOutputShape(const std::string& output_name) const;

#ifdef CMS_USE_ONNXRUNTIME
private:
static const ::Ort::Env env_;
std::unique_ptr<::Ort::Session> session_;
Expand All @@ -59,6 +71,7 @@ namespace cms::Ort {
std::vector<std::string> output_node_strings_;
std::vector<const char*> output_node_names_;
std::map<std::string, std::vector<int64_t>> output_node_dims_;
#endif
};

} // namespace cms::Ort
Expand Down
18 changes: 17 additions & 1 deletion PhysicsTools/ONNXRuntime/src/ONNXRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ namespace cms::Ort {

using namespace ::Ort;

#ifdef CMS_USE_ONNXRUNTIME
const Env ONNXRuntime::env_(ORT_LOGGING_LEVEL_WARNING, "");
#endif

ONNXRuntime::ONNXRuntime(const std::string& model_path, const SessionOptions* session_options) {
#ifdef CMS_USE_ONNXRUNTIME
// create session
if (session_options) {
session_.reset(new Session(env_, model_path.c_str(), *session_options));
Expand Down Expand Up @@ -76,6 +79,7 @@ namespace cms::Ort {
// the 0th dim depends on the batch size
output_node_dims_[output_name].at(0) = -1;
}
#endif
}

ONNXRuntime::~ONNXRuntime() {}
Expand All @@ -84,6 +88,7 @@ namespace cms::Ort {
FloatArrays& input_values,
const std::vector<std::string>& output_names,
int64_t batch_size) const {
#ifdef CMS_USE_ONNXRUNTIME
assert(input_names.size() == input_values.size());
assert(batch_size > 0);

Expand Down Expand Up @@ -142,23 +147,34 @@ namespace cms::Ort {
assert(outputs.size() == run_output_node_names.size());

return outputs;
#else
throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

const std::vector<std::string>& ONNXRuntime::getOutputNames() const {
#ifdef CMS_USE_ONNXRUNTIME
if (session_) {
return output_node_strings_;
} else {
throw cms::Exception("RuntimeError") << "Needs to call createSession() first before getting the output names!";
throw cms::Exception("RuntimeError") << "ONNXRuntime session is not initialized!";
}
#else
throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

const std::vector<int64_t>& ONNXRuntime::getOutputShape(const std::string& output_name) const {
#ifdef CMS_USE_ONNXRUNTIME
auto iter = output_node_dims_.find(output_name);
if (iter == output_node_dims_.end()) {
throw cms::Exception("RuntimeError") << "Output name " << output_name << " is invalid!";
} else {
return iter->second;
}
#else
throw cms::Exception("RuntimeError") << "ONNXRuntime does not support the current architecture";
#endif
}

} /* namespace cms::Ort */
5 changes: 5 additions & 0 deletions PhysicsTools/ONNXRuntime/test/testONNXRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "PhysicsTools/ONNXRuntime/interface/ONNXRuntime.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"
#include "FWCore/Utilities/interface/Exception.h"

#include <chrono>
#include <iostream>
Expand All @@ -27,11 +28,15 @@ void testONNXRuntime::checkAll() {
std::vector<float>(batch_size * 2, 1),
};
FloatArrays outputs;
#ifdef CMS_USE_ONNXRUNTIME
CPPUNIT_ASSERT_NO_THROW(outputs = rt.run({"X"}, input_values, {"Y"}, batch_size));
CPPUNIT_ASSERT(outputs.size() == 1);
CPPUNIT_ASSERT(outputs[0].size() == batch_size);
for (const auto &v : outputs[0]) {
CPPUNIT_ASSERT(v == 3);
}
#else
CPPUNIT_ASSERT_THROW(rt.run({"X"}, input_values, {"Y"}, batch_size), cms::Exception);
#endif
}
}
10 changes: 10 additions & 0 deletions PhysicsTools/PatAlgos/python/tools/jetTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def setupBTagging(process, jetSource, pfCandidates, explicitJTA, pvSource, svSou
process.load("RecoBTag.CTagging.cTagging_EventSetup_cff")
import RecoBTag.Configuration.RecoBTag_cff as btag
import RecoJets.JetProducers.caTopTaggers_cff as toptag
from RecoBTag.ONNXRuntime.SwitchProducerONNX import SwitchProducerONNX

if tightBTagNTkHits:
if not runIVF:
Expand Down Expand Up @@ -720,6 +721,15 @@ def setupBTagging(process, jetSource, pfCandidates, explicitJTA, pvSource, svSou
process,
task
)
elif isinstance(getattr(btag, btagDiscr), SwitchProducerONNX):
addToProcessAndTask(
newDiscr,
getattr(btag, btagDiscr).cloneAll(
src = btagPrefix + supportedBtagDiscr[discriminator_name][0][0] + labelName + postfix
),
process,
task
)
else:
raise ValueError('I do not know how to update %s it does not have neither "tagInfos" nor "src" attributes' % btagDiscr)
acceptedBtagDiscriminators.append(discriminator_name)
Expand Down
4 changes: 2 additions & 2 deletions PhysicsTools/TensorFlow/interface/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ namespace tensorflow {
// return a new session that will contain an already loaded graph def, sessionOptions are predefined
// an error is thrown when graphDef is a nullptr or when the grah has no nodes
// transfers ownership
Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions);
Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions);

// return a new session that will contain an already loaded graph def, threading options are
// inferred from nThreads
// an error is thrown when graphDef is a nullptr or when the grah has no nodes
// transfers ownership
Session* createSession(GraphDef* graphDef, int nThreads = 1);
Session* createSession(const GraphDef* graphDef, int nThreads = 1);

// closes a session, calls its destructor, resets the pointer, and returns true on success
bool closeSession(Session*& session);
Expand Down
4 changes: 2 additions & 2 deletions PhysicsTools/TensorFlow/src/TensorFlow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ namespace tensorflow {
return createSession(metaGraphDef, exportDir, sessionOptions);
}

Session* createSession(GraphDef* graphDef, SessionOptions& sessionOptions) {
Session* createSession(const GraphDef* graphDef, SessionOptions& sessionOptions) {
// check for valid pointer
if (graphDef == nullptr) {
throw cms::Exception("InvalidGraphDef") << "error while creating session: graphDef is nullptr";
Expand All @@ -185,7 +185,7 @@ namespace tensorflow {
return session;
}

Session* createSession(GraphDef* graphDef, int nThreads) {
Session* createSession(const GraphDef* graphDef, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
Expand Down
48 changes: 48 additions & 0 deletions RecoBTag/FeatureTools/interface/tensor_fillers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef RecoBTag_FeatureTools_tensor_fillers_h
#define RecoBTag_FeatureTools_tensor_fillers_h

#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleXTagInfo.h"

namespace btagbtvdeep {

void jet_tensor_filler(float* ptr, const btagbtvdeep::DeepFlavourFeatures& features, unsigned feature_dims);

void jet4vec_tensor_filler(float* ptr, const btagbtvdeep::DeepFlavourFeatures& features, unsigned feature_dims);

void db_tensor_filler(float* ptr, const btagbtvdeep::DeepDoubleXFeatures& features, unsigned feature_dims);

void c_pf_tensor_filler(float* ptr,
std::size_t max_c_pf_n,
const std::vector<btagbtvdeep::ChargedCandidateFeatures>& c_pf_features_vec,
unsigned feature_dims);

void c_pf_reduced_tensor_filler(float* ptr,
std::size_t max_c_pf_n,
const std::vector<btagbtvdeep::ChargedCandidateFeatures>& c_pf_features_vec,
unsigned feature_dims);

void n_pf_tensor_filler(float* ptr,
std::size_t max_n_pf_n,
const std::vector<btagbtvdeep::NeutralCandidateFeatures>& n_pf_features_vec,
unsigned feature_dims);

void sv_tensor_filler(float* ptr,
std::size_t max_sv_n,
const std::vector<btagbtvdeep::SecondaryVertexFeatures>& sv_features_vec,
unsigned feature_dims);

void sv_reduced_tensor_filler(float* ptr,
std::size_t max_sv_n,
const std::vector<btagbtvdeep::SecondaryVertexFeatures>& sv_features_vec,
unsigned feature_dims);

void seed_tensor_filler(float* ptr, const btagbtvdeep::SeedingTrackFeatures& seed_features, unsigned feature_dims);

void neighbourTracks_tensor_filler(float* ptr,
const btagbtvdeep::SeedingTrackFeatures& seed_features,
unsigned feature_dims);

} // namespace btagbtvdeep

#endif
Loading

0 comments on commit d8951bb

Please sign in to comment.