-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24505 from hqucms/deep-boosted-jets-rebase-102X
[102X][Backport] DeepAK8 tagger integration
- Loading branch information
Showing
55 changed files
with
2,394 additions
and
336 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#ifndef DataFormats_BTauReco_DeepBoostedJetFeatures_h | ||
#define DataFormats_BTauReco_DeepBoostedJetFeatures_h | ||
|
||
#include <string> | ||
#include <vector> | ||
#include <unordered_map> | ||
#include "FWCore/Utilities/interface/Exception.h" | ||
|
||
namespace btagbtvdeep { | ||
|
||
class DeepBoostedJetFeatures { | ||
|
||
public: | ||
|
||
bool empty() const { | ||
return is_empty_; | ||
} | ||
|
||
void add(const std::string& name){ | ||
feature_map_[name]; | ||
} | ||
|
||
void reserve(const std::string& name, unsigned capacity){ | ||
feature_map_[name].reserve(capacity); | ||
} | ||
|
||
void fill(const std::string& name, float value){ | ||
auto item = feature_map_.find(name); | ||
if (item != feature_map_.end()){ | ||
item->second.push_back(value); | ||
is_empty_ = false; | ||
}else{ | ||
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::fill()] Feature " << name << " has not been registered"; | ||
} | ||
} | ||
|
||
void set(const std::string& name, const std::vector<float>& vec){ | ||
feature_map_[name] = vec; | ||
} | ||
|
||
void check_consistency(const std::vector<std::string> &names) const { | ||
if (names.empty()) return; | ||
const auto ref_len = get(names.front()).size(); | ||
for (unsigned i=1; i<names.size(); ++i){ | ||
if (get(names[i]).size() != ref_len){ | ||
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::check_consistency()] Inconsistent variable length " | ||
<< get(names[i]).size() << " for " << names[i] << ", should be " << ref_len; | ||
} | ||
} | ||
} | ||
|
||
const std::vector<float>& get(const std::string& name) const { | ||
auto item = feature_map_.find(name); | ||
if (item != feature_map_.end()){ | ||
return item->second; | ||
}else{ | ||
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::get()] Feature " << name << " does not exist!"; | ||
} | ||
} | ||
|
||
private: | ||
bool is_empty_ = true; | ||
std::unordered_map<std::string, std::vector<float>> feature_map_; | ||
|
||
}; | ||
|
||
} | ||
|
||
#endif // DataFormats_BTauReco_DeepBoostedJetFeatures_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef DataFormats_BTauReco_DeepBoostedJetTagInfo_h | ||
#define DataFormats_BTauReco_DeepBoostedJetTagInfo_h | ||
|
||
#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h" | ||
#include "DataFormats/BTauReco/interface/DeepBoostedJetFeatures.h" | ||
|
||
namespace reco { | ||
|
||
typedef FeaturesTagInfo<btagbtvdeep::DeepBoostedJetFeatures> DeepBoostedJetTagInfo; | ||
|
||
DECLARE_EDM_REFS( DeepBoostedJetTagInfo ) | ||
|
||
} | ||
|
||
#endif // DataFormats_BTauReco_DeepBoostedJetTagInfo_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef DataFormats_BTauReco_FeaturesTagInfo_h | ||
#define DataFormats_BTauReco_FeaturesTagInfo_h | ||
|
||
#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h" | ||
#include "DataFormats/BTauReco/interface/BaseTagInfo.h" | ||
|
||
#include "DataFormats/PatCandidates/interface/Jet.h" | ||
|
||
namespace reco { | ||
|
||
template<class Features> class FeaturesTagInfo : public BaseTagInfo { | ||
|
||
public: | ||
|
||
FeaturesTagInfo() {} | ||
|
||
FeaturesTagInfo(const Features & features, | ||
const edm::RefToBase<Jet> & jet_ref) : | ||
features_(features), | ||
jet_ref_(jet_ref) {} | ||
|
||
edm::RefToBase<Jet> jet() const override { return jet_ref_; } | ||
|
||
const Features & features() const { return features_; } | ||
|
||
~FeaturesTagInfo() override {} | ||
// without overidding clone from base class will be store/retrieved | ||
FeaturesTagInfo* clone(void) const override { return new FeaturesTagInfo(*this); } | ||
|
||
|
||
CMS_CLASS_VERSION(3) | ||
|
||
private: | ||
Features features_; | ||
edm::RefToBase<Jet> jet_ref_; | ||
}; | ||
|
||
} | ||
|
||
#endif // DataFormats_BTauReco_FeaturesTagInfo_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
<use name="mxnet-predict"/> | ||
<use name="FWCore/Framework"/> | ||
<use name="FWCore/Utilities"/> | ||
<use name="FWCore/MessageLogger"/> | ||
<export> | ||
<lib name="1" /> | ||
</export> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
* MXNetCppPredictor.h | ||
* | ||
* Created on: Jul 19, 2018 | ||
* Author: hqu | ||
*/ | ||
|
||
#ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ | ||
#define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ | ||
|
||
#include <map> | ||
#include <vector> | ||
#include <memory> | ||
#include <mutex> | ||
|
||
#include "mxnet-cpp/MxNetCpp.h" | ||
|
||
namespace mxnet { | ||
|
||
namespace cpp { | ||
|
||
// note: Most of the objects in mxnet::cpp are effective just shared_ptr's | ||
|
||
// Simple class to hold MXNet model (symbol + params) | ||
// designed to be sharable by multiple threads | ||
class Block { | ||
public: | ||
Block(); | ||
Block(const std::string &symbol_file, const std::string ¶m_file); | ||
virtual ~Block(); | ||
|
||
const Symbol& symbol() const { return sym_; } | ||
Symbol symbol(const std::string &output_node) const { return sym_.GetInternals()[output_node]; } | ||
const std::map<std::string, NDArray>& arg_map() const { return arg_map_; } | ||
const std::map<std::string, NDArray>& aux_map() const { return aux_map_; } | ||
|
||
private: | ||
void load_parameters(const std::string& param_file); | ||
|
||
// symbol | ||
Symbol sym_; | ||
// argument arrays | ||
std::map<std::string, NDArray> arg_map_; | ||
// auxiliary arrays | ||
std::map<std::string, NDArray> aux_map_; | ||
}; | ||
|
||
// Simple helper class to run prediction | ||
// this cannot be shared between threads | ||
class Predictor { | ||
public: | ||
Predictor(); | ||
Predictor(const Block &block); | ||
Predictor(const Block &block, const std::string &output_node); | ||
virtual ~Predictor(); | ||
|
||
// set input array shapes | ||
void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes); | ||
|
||
// run prediction | ||
const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data); | ||
|
||
private: | ||
static std::mutex mutex_; | ||
|
||
void bind_executor(); | ||
|
||
// context | ||
static const Context context_; | ||
// executor | ||
std::unique_ptr<Executor> exec_; | ||
// symbol | ||
Symbol sym_; | ||
// argument arrays | ||
std::map<std::string, NDArray> arg_map_; | ||
// auxiliary arrays | ||
std::map<std::string, NDArray> aux_map_; | ||
// output of the prediction | ||
std::vector<float> pred_; | ||
// names of the input nodes | ||
std::vector<std::string> input_names_; | ||
|
||
}; | ||
|
||
} /* namespace cpp */ | ||
} /* namespace mxnet */ | ||
|
||
#endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */ |
Oops, something went wrong.