Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[102X][Backport] DeepAK8 tagger integration #24505

Merged
merged 31 commits into from
Oct 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
cdfb9a5
Add MXNet predictor based on the C API.
hqucms Jul 3, 2018
678c743
DataFormats change for DeepAK8.
hqucms Jul 3, 2018
55b6cd7
Add DeepAK8 tagger.
hqucms Jul 3, 2018
6aa6cca
Apply code-checks.
hqucms Jul 4, 2018
fc1fd12
Fix dependency of meta taggers.
hqucms Jul 4, 2018
20b9aec
Fix misuse of edm::Ptr and some clean-ups.
hqucms Jul 5, 2018
d7a1c26
Override clearDaughters in pat::Jet to reset the daughter cache too.
hqucms Jul 5, 2018
d0e4e1a
Some improvements and clean-ups.
hqucms Jul 6, 2018
794b953
Add a test unit for MXNetPredictor.
hqucms Jul 6, 2018
e1e9e1a
Fix DeepBoostedJetTagInfoProducer.
hqucms Jul 6, 2018
70477ac
Need to sort subjets by pt.
hqucms Jul 8, 2018
957cf9d
Thread-safety and better logging in MXNetPredictor.
hqucms Jul 20, 2018
a976d4e
Switch to MXNet C++ API.
hqucms Jul 22, 2018
830e1c3
Protect executor creatation with a mutex in MXNetCppPredictor.
hqucms Aug 8, 2018
0966b24
Refactor MXNetCppPredictor.
hqucms Aug 12, 2018
e8632e0
Remove MXNet C API.
hqucms Aug 13, 2018
03fc196
Move non-TF code in `RecoBTag/TensorFlow` to `RecoBTag/FeatureTools`
hqucms Aug 13, 2018
f13af0c
Code style changes.
hqucms Aug 13, 2018
8997d92
Fix BuildFile.
hqucms Aug 13, 2018
b12fd79
Modify MXNet source to avoid rebinding.
hqucms Aug 13, 2018
d1eaf16
Fix BuildFile.
hqucms Aug 13, 2018
f4562fa
Improve the checks when setting up pfDeepBoostedJetTagInfos.
hqucms Aug 13, 2018
3aae32e
Update DeepBoostedJet meta-taggers.
hqucms Aug 31, 2018
2fcd0fb
Rename MXNet Predictor.
hqucms Aug 31, 2018
fad4dc3
Style and performance improvements.
hqucms Sep 3, 2018
22d3cb5
Mark DeepBoostedJetTagInfo as transient.
hqucms Sep 3, 2018
4e5fb07
Reorganize TagInfo producers and JetTags producers.
hqucms Sep 3, 2018
0f3a195
A few fixes.
hqucms Sep 3, 2018
1f639cf
Improve the setup of DeepBoostedJetTagInfos.
hqucms Sep 7, 2018
32b2b2c
Disable DeepBoostedJet in applyDeepBtagging_cff.py.
hqucms Sep 11, 2018
ed815a3
Revert fix on clearDaughters. Use pf2pc instead.
hqucms Sep 12, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions DataFormats/BTauReco/interface/DeepBoostedJetFeatures.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#ifndef DataFormats_BTauReco_DeepBoostedJetFeatures_h
#define DataFormats_BTauReco_DeepBoostedJetFeatures_h

#include <string>
#include <vector>
#include <unordered_map>
#include "FWCore/Utilities/interface/Exception.h"

namespace btagbtvdeep {

class DeepBoostedJetFeatures {

public:

bool empty() const {
return is_empty_;
}

void add(const std::string& name){
feature_map_[name];
}

void reserve(const std::string& name, unsigned capacity){
feature_map_[name].reserve(capacity);
}

void fill(const std::string& name, float value){
auto item = feature_map_.find(name);
if (item != feature_map_.end()){
item->second.push_back(value);
is_empty_ = false;
}else{
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::fill()] Feature " << name << " has not been registered";
}
}

void set(const std::string& name, const std::vector<float>& vec){
feature_map_[name] = vec;
}

void check_consistency(const std::vector<std::string> &names) const {
if (names.empty()) return;
const auto ref_len = get(names.front()).size();
for (unsigned i=1; i<names.size(); ++i){
if (get(names[i]).size() != ref_len){
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::check_consistency()] Inconsistent variable length "
<< get(names[i]).size() << " for " << names[i] << ", should be " << ref_len;
}
}
}

const std::vector<float>& get(const std::string& name) const {
auto item = feature_map_.find(name);
if (item != feature_map_.end()){
return item->second;
}else{
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::get()] Feature " << name << " does not exist!";
}
}

private:
bool is_empty_ = true;
std::unordered_map<std::string, std::vector<float>> feature_map_;

};

}

#endif // DataFormats_BTauReco_DeepBoostedJetFeatures_h
15 changes: 15 additions & 0 deletions DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef DataFormats_BTauReco_DeepBoostedJetTagInfo_h
#define DataFormats_BTauReco_DeepBoostedJetTagInfo_h

#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepBoostedJetFeatures.h"

namespace reco {

typedef FeaturesTagInfo<btagbtvdeep::DeepBoostedJetFeatures> DeepBoostedJetTagInfo;

DECLARE_EDM_REFS( DeepBoostedJetTagInfo )

}

#endif // DataFormats_BTauReco_DeepBoostedJetTagInfo_h
6 changes: 1 addition & 5 deletions DataFormats/BTauReco/interface/DeepDoubleBTagInfo.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
#ifndef DataFormats_BTauReco_DeepDoubleBTagInfo_h
#define DataFormats_BTauReco_DeepDoubleBTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBFeatures.h"

#include "DataFormats/PatCandidates/interface/Jet.h"

namespace reco {

typedef FeaturesTagInfo<btagbtvdeep::DeepDoubleBFeatures> DeepDoubleBTagInfo;
Expand Down
32 changes: 1 addition & 31 deletions DataFormats/BTauReco/interface/DeepFlavourTagInfo.h
Original file line number Diff line number Diff line change
@@ -1,41 +1,11 @@
#ifndef DataFormats_BTauReco_DeepFlavourTagInfo_h
#define DataFormats_BTauReco_DeepFlavourTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepFlavourFeatures.h"

#include "DataFormats/PatCandidates/interface/Jet.h"
#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"

namespace reco {

template<class Features> class FeaturesTagInfo : public BaseTagInfo {

public:

FeaturesTagInfo() {}

FeaturesTagInfo(const Features & features,
const edm::RefToBase<Jet> & jet_ref) :
features_(features),
jet_ref_(jet_ref) {}

edm::RefToBase<Jet> jet() const override { return jet_ref_; }

const Features & features() const { return features_; }

~FeaturesTagInfo() override {}
// without overidding clone from base class will be store/retrieved
FeaturesTagInfo* clone(void) const override { return new FeaturesTagInfo(*this); }


CMS_CLASS_VERSION(3)

private:
Features features_;
edm::RefToBase<Jet> jet_ref_;
};

typedef FeaturesTagInfo<btagbtvdeep::DeepFlavourFeatures> DeepFlavourTagInfo;

DECLARE_EDM_REFS( DeepFlavourTagInfo )
Expand Down
40 changes: 40 additions & 0 deletions DataFormats/BTauReco/interface/FeaturesTagInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef DataFormats_BTauReco_FeaturesTagInfo_h
#define DataFormats_BTauReco_FeaturesTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"

#include "DataFormats/PatCandidates/interface/Jet.h"

namespace reco {

template<class Features> class FeaturesTagInfo : public BaseTagInfo {

public:

FeaturesTagInfo() {}

FeaturesTagInfo(const Features & features,
const edm::RefToBase<Jet> & jet_ref) :
features_(features),
jet_ref_(jet_ref) {}

edm::RefToBase<Jet> jet() const override { return jet_ref_; }

const Features & features() const { return features_; }

~FeaturesTagInfo() override {}
// without overidding clone from base class will be store/retrieved
FeaturesTagInfo* clone(void) const override { return new FeaturesTagInfo(*this); }


CMS_CLASS_VERSION(3)

private:
Features features_;
edm::RefToBase<Jet> jet_ref_;
};

}

#endif // DataFormats_BTauReco_FeaturesTagInfo_h
11 changes: 10 additions & 1 deletion DataFormats/BTauReco/src/classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBFeatures.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"


namespace reco {
Expand Down Expand Up @@ -289,7 +290,7 @@ namespace DataFormats_BTauReco {
reco::HTTTopJetTagInfoRefVector htttopjet_rv;
edm::Wrapper<reco::HTTTopJetTagInfoCollection> htttopjet_wc;
edm::reftobase::Holder<reco::BaseTagInfo, reco::HTTTopJetTagInfoRef> rb_htttopjet;
edm::reftobase::RefHolder<reco::HTTTopJetTagInfoRef> rbh_htttopjet;
edm::reftobase::RefHolder<reco::HTTTopJetTagInfoRef> rbh_htttopjet;

std::vector<Measurement1D> vm1d;

Expand Down Expand Up @@ -427,5 +428,13 @@ namespace DataFormats_BTauReco {
reco::DeepDoubleBTagInfoRefVector deep_doubleb_tag_info_collection_ref_vector;
edm::Wrapper<reco::DeepDoubleBTagInfoCollection> deep_doubleb_tag_info_collection_edm_wrapper;

btagbtvdeep::DeepBoostedJetFeatures deep_boosted_jet_tag_info_features;
reco::DeepBoostedJetTagInfo deep_boosted_jet_tag_info;
reco::DeepBoostedJetTagInfoCollection deep_boosted_jet_tag_info_collection;
reco::DeepBoostedJetTagInfoRef deep_boosted_jet_tag_info_collection_ref;
reco::DeepBoostedJetTagInfoFwdRef deep_boosted_jet_tag_info_collection_fwd_ref;
reco::DeepBoostedJetTagInfoRefProd deep_boosted_jet_tag_info_collection_ref_prod;
reco::DeepBoostedJetTagInfoRefVector deep_boosted_jet_tag_info_collection_ref_vector;
edm::Wrapper<reco::DeepBoostedJetTagInfoCollection> deep_boosted_jet_tag_info_collection_edm_wrapper;
};
}
9 changes: 9 additions & 0 deletions DataFormats/BTauReco/src/classes_def.xml
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,13 @@
<class name="reco::DeepDoubleBTagInfoRefVector"/>
<class name="edm::Wrapper<reco::DeepDoubleBTagInfoCollection>"/>

<class name="btagbtvdeep::DeepBoostedJetFeatures"/>
<class name="reco::DeepBoostedJetTagInfo"/>
<class name="reco::DeepBoostedJetTagInfoCollection"/>
<class name="reco::DeepBoostedJetTagInfoRef"/>
<class name="reco::DeepBoostedJetTagInfoFwdRef"/>
<class name="reco::DeepBoostedJetTagInfoRefProd"/>
<class name="reco::DeepBoostedJetTagInfoRefVector"/>
<class name="edm::Wrapper<reco::DeepBoostedJetTagInfoCollection>" persistent="false"/>

</lcgdict>
7 changes: 7 additions & 0 deletions PhysicsTools/MXNet/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<use name="mxnet-predict"/>
<use name="FWCore/Framework"/>
<use name="FWCore/Utilities"/>
<use name="FWCore/MessageLogger"/>
<export>
<lib name="1" />
</export>
88 changes: 88 additions & 0 deletions PhysicsTools/MXNet/interface/Predictor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* MXNetCppPredictor.h
*
* Created on: Jul 19, 2018
* Author: hqu
*/

#ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
#define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_

#include <map>
#include <vector>
#include <memory>
#include <mutex>

#include "mxnet-cpp/MxNetCpp.h"

namespace mxnet {

namespace cpp {

// note: Most of the objects in mxnet::cpp are effective just shared_ptr's

// Simple class to hold MXNet model (symbol + params)
// designed to be sharable by multiple threads
class Block {
public:
Block();
Block(const std::string &symbol_file, const std::string &param_file);
virtual ~Block();

const Symbol& symbol() const { return sym_; }
Symbol symbol(const std::string &output_node) const { return sym_.GetInternals()[output_node]; }
const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }

private:
void load_parameters(const std::string& param_file);

// symbol
Symbol sym_;
// argument arrays
std::map<std::string, NDArray> arg_map_;
// auxiliary arrays
std::map<std::string, NDArray> aux_map_;
};

// Simple helper class to run prediction
// this cannot be shared between threads
class Predictor {
public:
Predictor();
Predictor(const Block &block);
Predictor(const Block &block, const std::string &output_node);
virtual ~Predictor();

// set input array shapes
void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);

// run prediction
const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);

private:
static std::mutex mutex_;

void bind_executor();

// context
static const Context context_;
// executor
std::unique_ptr<Executor> exec_;
// symbol
Symbol sym_;
// argument arrays
std::map<std::string, NDArray> arg_map_;
// auxiliary arrays
std::map<std::string, NDArray> aux_map_;
// output of the prediction
std::vector<float> pred_;
// names of the input nodes
std::vector<std::string> input_names_;

};

} /* namespace cpp */
} /* namespace mxnet */

#endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */
Loading