Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forward-porting DNN-related from branch cms-tau-pog:CMSSW_9_4_X_tau_pog_DNNTauIDs #111

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
7df8fdb
Remove assigning value of variable to itself
MRD2F Nov 29, 2018
8c3fb20
- Implemented on runTauIdMVA the option to work with new training fil…
MRD2F Oct 31, 2018
0eb0fda
- Implementation of global cache to avoid reloading graph for each th…
MRD2F Nov 2, 2018
c34d583
Applied changes on DeepTauBase to allow load new training files using…
MRD2F Nov 5, 2018
45c0003
Implemented TauWPThreshold class.
kandrosov Nov 2, 2018
ea94956
Remove the qm.pb input files and leaving just the quantized and the o…
MRD2F Nov 5, 2018
ee097de
-Overall, changes to improve memory usage, among these are:
MRD2F Nov 6, 2018
5d4d15a
Applied style comments
MRD2F Nov 8, 2018
a82f820
Applied style comments
MRD2F Nov 8, 2018
b595d46
Applied comments
MRD2F Nov 8, 2018
ac1ce8e
Change to be by default the original training file for deepTau, inste…
MRD2F Nov 8, 2018
505f4e8
Changes regarding forward-porting DNN-related developments from the P…
MRD2F Nov 17, 2018
cdce09b
Applied commets of previus PR
MRD2F Nov 13, 2018
6cb1168
cleaning code
MRD2F Nov 13, 2018
92df2ae
Modification in the config to work with new label in files
MRD2F Nov 14, 2018
8a0ab25
Applied comment about the expected format of name of training file
MRD2F Nov 14, 2018
72b193a
Fix in last commit
MRD2F Nov 14, 2018
cedab33
Applied last comments
MRD2F Nov 14, 2018
df546a5
Changes regarding forward-porting DNN-related developments from the P…
MRD2F Nov 17, 2018
95cca7a
Applied @perrotta comments on 104X
MRD2F Nov 21, 2018
daac1a3
Fix error
MRD2F Nov 21, 2018
3d5ab7c
Applied comments
MRD2F Nov 21, 2018
4344f17
Applied comments
MRD2F Nov 22, 2018
139ec49
Fix merge problem
MRD2F Nov 22, 2018
3142b41
Applied a few commets
MRD2F Nov 22, 2018
c8ef218
Applied more changes
MRD2F Nov 22, 2018
1da62c7
Applied a few small followups
MRD2F Nov 22, 2018
36a67b4
Fixed error on DPFIsolation
MRD2F Nov 22, 2018
99dbf40
Update DPFIsolation.cc
mbluj Nov 23, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions RecoTauTag/RecoTau/interface/DeepTauBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#include "DataFormats/PatCandidates/interface/Electron.h"
#include "DataFormats/PatCandidates/interface/Muon.h"
#include "DataFormats/PatCandidates/interface/Tau.h"
Expand All @@ -22,10 +23,39 @@
#include "RecoTauTag/RecoTau/interface/PFRecoTauClusterVariables.h"
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include <TF1.h>

namespace deep_tau {

class DeepTauBase : public edm::stream::EDProducer<> {
class TauWPThreshold {
public:
explicit TauWPThreshold(const std::string& cut_str);
double operator()(const pat::Tau& tau) const;

private:
std::unique_ptr<TF1> fn_;
double value_;
};

class DeepTauCache {
public:
using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;

DeepTauCache(const std::string& graph_name, bool mem_mapped);
~DeepTauCache();

// A Session allows concurrent calls to Run(), though a Session must
// be created / extended by a single thread.
tensorflow::Session& getSession() const { return *session_; }
const tensorflow::GraphDef& getGraph() const { return *graph_; }

private:
GraphPtr graph_;
tensorflow::Session* session_;
std::unique_ptr<tensorflow::MemmappedEnv> memmappedEnv_;
};

class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache>> {
public:
using TauType = pat::Tau;
using TauDiscriminator = pat::PATTauDiscriminator;
Expand All @@ -35,17 +65,15 @@ class DeepTauBase : public edm::stream::EDProducer<> {
using ElectronCollection = pat::ElectronCollection;
using MuonCollection = pat::MuonCollection;
using LorentzVectorXYZ = ROOT::Math::LorentzVector<ROOT::Math::PxPyPzE4D<double>>;
using GraphPtr = std::shared_ptr<tensorflow::GraphDef>;
using Cutter = StringObjectFunction<TauType>;
using Cutter = TauWPThreshold;
using CutterPtr = std::unique_ptr<Cutter>;
using WPMap = std::map<std::string, CutterPtr>;


struct Output {
using ResultMap = std::map<std::string, std::unique_ptr<TauDiscriminator>>;
std::vector<size_t> num, den;
std::vector<size_t> num_, den_;

Output(const std::vector<size_t>& _num, const std::vector<size_t>& _den) : num(_num), den(_den) {}
Output(const std::vector<size_t>& num, const std::vector<size_t>& den) : num_(num), den_(den) {}

ResultMap get_value(const edm::Handle<TauCollection>& taus, const tensorflow::Tensor& pred,
const WPMap& working_points) const;
Expand All @@ -54,27 +82,25 @@ class DeepTauBase : public edm::stream::EDProducer<> {
using OutputCollection = std::map<std::string, Output>;


DeepTauBase(const edm::ParameterSet& cfg, const OutputCollection& outputs);
virtual ~DeepTauBase();
DeepTauBase(const edm::ParameterSet& cfg, const OutputCollection& outputs, const DeepTauCache* cache);
virtual ~DeepTauBase() {}

virtual void produce(edm::Event& event, const edm::EventSetup& es) override;

static std::unique_ptr<DeepTauCache> initializeGlobalCache(const edm::ParameterSet& cfg);
static void globalEndJob(const DeepTauCache* cache){ }
private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
virtual tensorflow::Tensor getPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) = 0;
virtual void CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);
virtual void createOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);

protected:
edm::EDGetTokenT<TauCollection> taus_token;
std::string graphName;
GraphPtr graph;
tensorflow::Session* session;
std::map<std::string, WPMap> working_points;
OutputCollection outputs;
edm::EDGetTokenT<TauCollection> tausToken_;
std::map<std::string, WPMap> workingPoints_;
OutputCollection outputs_;
const DeepTauCache* cache_;
};

} // namespace deep_tau



#endif
Loading