Skip to content

Commit

Permalink
Merge pull request #24505 from hqucms/deep-boosted-jets-rebase-102X
Browse files Browse the repository at this point in the history
[102X][Backport] DeepAK8 tagger integration
  • Loading branch information
cmsbuild authored Oct 3, 2018
2 parents 6574a0b + ed815a3 commit fe0f9a9
Show file tree
Hide file tree
Showing 55 changed files with 2,394 additions and 336 deletions.
69 changes: 69 additions & 0 deletions DataFormats/BTauReco/interface/DeepBoostedJetFeatures.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#ifndef DataFormats_BTauReco_DeepBoostedJetFeatures_h
#define DataFormats_BTauReco_DeepBoostedJetFeatures_h

#include <string>
#include <vector>
#include <unordered_map>
#include "FWCore/Utilities/interface/Exception.h"

namespace btagbtvdeep {

class DeepBoostedJetFeatures {

public:

bool empty() const {
return is_empty_;
}

void add(const std::string& name){
feature_map_[name];
}

void reserve(const std::string& name, unsigned capacity){
feature_map_[name].reserve(capacity);
}

void fill(const std::string& name, float value){
auto item = feature_map_.find(name);
if (item != feature_map_.end()){
item->second.push_back(value);
is_empty_ = false;
}else{
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::fill()] Feature " << name << " has not been registered";
}
}

void set(const std::string& name, const std::vector<float>& vec){
feature_map_[name] = vec;
}

void check_consistency(const std::vector<std::string> &names) const {
if (names.empty()) return;
const auto ref_len = get(names.front()).size();
for (unsigned i=1; i<names.size(); ++i){
if (get(names[i]).size() != ref_len){
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::check_consistency()] Inconsistent variable length "
<< get(names[i]).size() << " for " << names[i] << ", should be " << ref_len;
}
}
}

const std::vector<float>& get(const std::string& name) const {
auto item = feature_map_.find(name);
if (item != feature_map_.end()){
return item->second;
}else{
throw cms::Exception("InvalidArgument") << "[DeepBoostedJetFeatures::get()] Feature " << name << " does not exist!";
}
}

private:
bool is_empty_ = true;
std::unordered_map<std::string, std::vector<float>> feature_map_;

};

}

#endif // DataFormats_BTauReco_DeepBoostedJetFeatures_h
15 changes: 15 additions & 0 deletions DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef DataFormats_BTauReco_DeepBoostedJetTagInfo_h
#define DataFormats_BTauReco_DeepBoostedJetTagInfo_h

#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepBoostedJetFeatures.h"

namespace reco {

typedef FeaturesTagInfo<btagbtvdeep::DeepBoostedJetFeatures> DeepBoostedJetTagInfo;

DECLARE_EDM_REFS( DeepBoostedJetTagInfo )

}

#endif // DataFormats_BTauReco_DeepBoostedJetTagInfo_h
6 changes: 1 addition & 5 deletions DataFormats/BTauReco/interface/DeepDoubleBTagInfo.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
#ifndef DataFormats_BTauReco_DeepDoubleBTagInfo_h
#define DataFormats_BTauReco_DeepDoubleBTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBFeatures.h"

#include "DataFormats/PatCandidates/interface/Jet.h"

namespace reco {

typedef FeaturesTagInfo<btagbtvdeep::DeepDoubleBFeatures> DeepDoubleBTagInfo;
Expand Down
32 changes: 1 addition & 31 deletions DataFormats/BTauReco/interface/DeepFlavourTagInfo.h
Original file line number Diff line number Diff line change
@@ -1,41 +1,11 @@
#ifndef DataFormats_BTauReco_DeepFlavourTagInfo_h
#define DataFormats_BTauReco_DeepFlavourTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepFlavourFeatures.h"

#include "DataFormats/PatCandidates/interface/Jet.h"
#include "DataFormats/BTauReco/interface/FeaturesTagInfo.h"

namespace reco {

template<class Features> class FeaturesTagInfo : public BaseTagInfo {

public:

FeaturesTagInfo() {}

FeaturesTagInfo(const Features & features,
const edm::RefToBase<Jet> & jet_ref) :
features_(features),
jet_ref_(jet_ref) {}

edm::RefToBase<Jet> jet() const override { return jet_ref_; }

const Features & features() const { return features_; }

~FeaturesTagInfo() override {}
// without overidding clone from base class will be store/retrieved
FeaturesTagInfo* clone(void) const override { return new FeaturesTagInfo(*this); }


CMS_CLASS_VERSION(3)

private:
Features features_;
edm::RefToBase<Jet> jet_ref_;
};

typedef FeaturesTagInfo<btagbtvdeep::DeepFlavourFeatures> DeepFlavourTagInfo;

DECLARE_EDM_REFS( DeepFlavourTagInfo )
Expand Down
40 changes: 40 additions & 0 deletions DataFormats/BTauReco/interface/FeaturesTagInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef DataFormats_BTauReco_FeaturesTagInfo_h
#define DataFormats_BTauReco_FeaturesTagInfo_h

#include "DataFormats/Common/interface/CMS_CLASS_VERSION.h"
#include "DataFormats/BTauReco/interface/BaseTagInfo.h"

#include "DataFormats/PatCandidates/interface/Jet.h"

namespace reco {

template<class Features> class FeaturesTagInfo : public BaseTagInfo {

public:

FeaturesTagInfo() {}

FeaturesTagInfo(const Features & features,
const edm::RefToBase<Jet> & jet_ref) :
features_(features),
jet_ref_(jet_ref) {}

edm::RefToBase<Jet> jet() const override { return jet_ref_; }

const Features & features() const { return features_; }

~FeaturesTagInfo() override {}
// without overidding clone from base class will be store/retrieved
FeaturesTagInfo* clone(void) const override { return new FeaturesTagInfo(*this); }


CMS_CLASS_VERSION(3)

private:
Features features_;
edm::RefToBase<Jet> jet_ref_;
};

}

#endif // DataFormats_BTauReco_FeaturesTagInfo_h
11 changes: 10 additions & 1 deletion DataFormats/BTauReco/src/classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "DataFormats/BTauReco/interface/DeepFlavourTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBFeatures.h"
#include "DataFormats/BTauReco/interface/DeepDoubleBTagInfo.h"
#include "DataFormats/BTauReco/interface/DeepBoostedJetTagInfo.h"


namespace reco {
Expand Down Expand Up @@ -289,7 +290,7 @@ namespace DataFormats_BTauReco {
reco::HTTTopJetTagInfoRefVector htttopjet_rv;
edm::Wrapper<reco::HTTTopJetTagInfoCollection> htttopjet_wc;
edm::reftobase::Holder<reco::BaseTagInfo, reco::HTTTopJetTagInfoRef> rb_htttopjet;
edm::reftobase::RefHolder<reco::HTTTopJetTagInfoRef> rbh_htttopjet;
edm::reftobase::RefHolder<reco::HTTTopJetTagInfoRef> rbh_htttopjet;

std::vector<Measurement1D> vm1d;

Expand Down Expand Up @@ -427,5 +428,13 @@ namespace DataFormats_BTauReco {
reco::DeepDoubleBTagInfoRefVector deep_doubleb_tag_info_collection_ref_vector;
edm::Wrapper<reco::DeepDoubleBTagInfoCollection> deep_doubleb_tag_info_collection_edm_wrapper;

btagbtvdeep::DeepBoostedJetFeatures deep_boosted_jet_tag_info_features;
reco::DeepBoostedJetTagInfo deep_boosted_jet_tag_info;
reco::DeepBoostedJetTagInfoCollection deep_boosted_jet_tag_info_collection;
reco::DeepBoostedJetTagInfoRef deep_boosted_jet_tag_info_collection_ref;
reco::DeepBoostedJetTagInfoFwdRef deep_boosted_jet_tag_info_collection_fwd_ref;
reco::DeepBoostedJetTagInfoRefProd deep_boosted_jet_tag_info_collection_ref_prod;
reco::DeepBoostedJetTagInfoRefVector deep_boosted_jet_tag_info_collection_ref_vector;
edm::Wrapper<reco::DeepBoostedJetTagInfoCollection> deep_boosted_jet_tag_info_collection_edm_wrapper;
};
}
9 changes: 9 additions & 0 deletions DataFormats/BTauReco/src/classes_def.xml
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,13 @@
<class name="reco::DeepDoubleBTagInfoRefVector"/>
<class name="edm::Wrapper<reco::DeepDoubleBTagInfoCollection>"/>

<class name="btagbtvdeep::DeepBoostedJetFeatures"/>
<class name="reco::DeepBoostedJetTagInfo"/>
<class name="reco::DeepBoostedJetTagInfoCollection"/>
<class name="reco::DeepBoostedJetTagInfoRef"/>
<class name="reco::DeepBoostedJetTagInfoFwdRef"/>
<class name="reco::DeepBoostedJetTagInfoRefProd"/>
<class name="reco::DeepBoostedJetTagInfoRefVector"/>
<class name="edm::Wrapper<reco::DeepBoostedJetTagInfoCollection>" persistent="false"/>

</lcgdict>
7 changes: 7 additions & 0 deletions PhysicsTools/MXNet/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<use name="mxnet-predict"/>
<use name="FWCore/Framework"/>
<use name="FWCore/Utilities"/>
<use name="FWCore/MessageLogger"/>
<export>
<lib name="1" />
</export>
88 changes: 88 additions & 0 deletions PhysicsTools/MXNet/interface/Predictor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* MXNetCppPredictor.h
*
* Created on: Jul 19, 2018
* Author: hqu
*/

#ifndef PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_
#define PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_

#include <map>
#include <vector>
#include <memory>
#include <mutex>

#include "mxnet-cpp/MxNetCpp.h"

namespace mxnet {

namespace cpp {

// note: Most of the objects in mxnet::cpp are effective just shared_ptr's

// Simple class to hold MXNet model (symbol + params)
// designed to be sharable by multiple threads
class Block {
public:
Block();
Block(const std::string &symbol_file, const std::string &param_file);
virtual ~Block();

const Symbol& symbol() const { return sym_; }
Symbol symbol(const std::string &output_node) const { return sym_.GetInternals()[output_node]; }
const std::map<std::string, NDArray>& arg_map() const { return arg_map_; }
const std::map<std::string, NDArray>& aux_map() const { return aux_map_; }

private:
void load_parameters(const std::string& param_file);

// symbol
Symbol sym_;
// argument arrays
std::map<std::string, NDArray> arg_map_;
// auxiliary arrays
std::map<std::string, NDArray> aux_map_;
};

// Simple helper class to run prediction
// this cannot be shared between threads
class Predictor {
public:
Predictor();
Predictor(const Block &block);
Predictor(const Block &block, const std::string &output_node);
virtual ~Predictor();

// set input array shapes
void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);

// run prediction
const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);

private:
static std::mutex mutex_;

void bind_executor();

// context
static const Context context_;
// executor
std::unique_ptr<Executor> exec_;
// symbol
Symbol sym_;
// argument arrays
std::map<std::string, NDArray> arg_map_;
// auxiliary arrays
std::map<std::string, NDArray> aux_map_;
// output of the prediction
std::vector<float> pred_;
// names of the input nodes
std::vector<std::string> input_names_;

};

} /* namespace cpp */
} /* namespace mxnet */

#endif /* PHYSICSTOOLS_MXNET_MXNETCPPPREDICTOR_H_ */
Loading

0 comments on commit fe0f9a9

Please sign in to comment.