Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made DeepTauId and DPFIsolation thread-safe #101

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions RecoTauTag/RecoTau/interface/DeepTauBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,17 @@ class DeepTauBase : public edm::stream::EDProducer<> {
virtual void produce(edm::Event& event, const edm::EventSetup& es) override;

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) = 0;
virtual void CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred);
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) = 0;
virtual void CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);

protected:
edm::EDGetTokenT<TauCollection> taus_token;
edm::Handle<TauCollection> taus;
std::string graphName;
GraphPtr graph;
tensorflow::Session* session;
std::map<std::string, WPMap> working_points;
OutputCollection outputs;

};

} // namespace deep_tau
Expand Down
22 changes: 11 additions & 11 deletions RecoTauTag/RecoTau/plugins/DPFIsolation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,25 @@ class DPFIsolation : public deep_tau::DeepTauBase {
graphVersion = 1;
else
throw cms::Exception("DPFIsolation") << "unknown version of the graph file.";

tensor = tensorflow::Tensor(tensorflow::DT_FLOAT, {1,
static_cast<int>(GetNumberOfParticles(graphVersion)), static_cast<int>(GetNumberOfFeatures(graphVersion))});
}

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) override
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) override
{
edm::Handle<pat::PackedCandidateCollection> pfcands;
event.getByToken(pfcand_token, pfcands);

edm::Handle<reco::VertexCollection> vertices;
event.getByToken(vtx_token, vertices);

tensorflow::Tensor tensor(tensorflow::DT_FLOAT, {1,
static_cast<int>(GetNumberOfParticles(graphVersion)), static_cast<int>(GetNumberOfFeatures(graphVersion))});

tensorflow::Tensor predictions(tensorflow::DT_FLOAT, { static_cast<int>(taus->size()), 1});

std::vector<tensorflow::Tensor> outputs;

float pfCandPt, pfCandPz, pfCandPtRel, pfCandPzRel, pfCandDr, pfCandDEta, pfCandDPhi, pfCandEta, pfCandDz,
pfCandDzErr, pfCandD0, pfCandD0D0, pfCandD0Dz, pfCandD0Dphi, pfCandPuppiWeight,
pfCandPixHits, pfCandHits, pfCandLostInnerHits, pfCandPdgID, pfCandCharge, pfCandFromPV,
Expand Down Expand Up @@ -378,14 +384,8 @@ class DPFIsolation : public deep_tau::DeepTauBase {

private:
edm::EDGetTokenT<pat::PackedCandidateCollection> pfcand_token;
edm::EDGetTokenT<reco::VertexCollection> vtx_token;

edm::Handle<pat::PackedCandidateCollection> pfcands;
edm::Handle<reco::VertexCollection> vertices;

edm::EDGetTokenT<reco::VertexCollection> vtx_token;
unsigned graphVersion;
tensorflow::Tensor tensor;
std::vector<tensorflow::Tensor> outputs;
};

#include "FWCore/Framework/interface/MakerMacros.h"
Expand Down
11 changes: 5 additions & 6 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ class DeepTauId : public deep_tau::DeepTauBase {
}

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) override
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) override
{
edm::Handle<pat::ElectronCollection> electrons;
event.getByToken(electrons_token, electrons);
Expand Down Expand Up @@ -306,9 +307,9 @@ class DeepTauId : public deep_tau::DeepTauBase {
void SetInputs(const TauCollection& taus, size_t tau_index, tensorflow::Tensor& inputs,
const ElectronCollection& electrons, const MuonCollection& muons) const
{

static constexpr bool check_all_set = false;
static constexpr float magic_number = -42;
static const TauIdMVAAuxiliaries clusterVariables;
const auto& get = [&](int var_index) -> float& { return inputs.matrix<float>()(tau_index, var_index); };
const TauType& tau = taus.at(tau_index);
auto leadChargedHadrCand = dynamic_cast<const pat::PackedCandidate*>(tau.leadChargedHadrCand().get());
Expand Down Expand Up @@ -604,9 +605,8 @@ class DeepTauId : public deep_tau::DeepTauBase {
const double dR2 = deltaR*deltaR;
const pat::Electron* matched_ele = nullptr;
for(const auto& ele : electrons) {
if(reco::deltaR2(tau.p4(), ele.p4()) < dR2 &&
(!matched_ele || matched_ele->pt() < ele.pt())) {
matched_ele = &ele;
if(reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
matched_ele = &ele;
}
}
return matched_ele;
Expand All @@ -616,7 +616,6 @@ class DeepTauId : public deep_tau::DeepTauBase {
edm::EDGetTokenT<ElectronCollection> electrons_token;
edm::EDGetTokenT<MuonCollection> muons_token;
std::string input_layer, output_layer;
TauIdMVAAuxiliaries clusterVariables;
};

#include "FWCore/Framework/interface/MakerMacros.h"
Expand Down
4 changes: 3 additions & 1 deletion RecoTauTag/RecoTau/python/tools/runTauIdMVA.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def runTauID(self):
"Loose": "0.999755",
"Medium": "0.999854",
"Tight": "0.999886",
"VTight": "0.99994",
"VTight": "0.999944",
"VVTight": "0.9999971"
},

Expand Down Expand Up @@ -666,6 +666,8 @@ def runTauID(self):

if "DPFTau_2016_v1" in self.toKeep:
print "Adding DPFTau isolation (v1)"
print "WARNING: WPs are not defined for DPFTau_2016_v1"
print "WARNING: The score of DPFTau_2016_v1 is inverted: i.e. for Sig->0, for Bkg->1 with -1 for undefined input (preselection not passed)."

working_points = {
"all": {"Tight" : "0.123"} #FIXME: define WP
Expand Down
8 changes: 5 additions & 3 deletions RecoTauTag/RecoTau/src/DeepTauBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ DeepTauBase::~DeepTauBase()

void DeepTauBase::produce(edm::Event& event, const edm::EventSetup& es)
{
edm::Handle<TauCollection> taus;
event.getByToken(taus_token, taus);
const tensorflow::Tensor& pred = GetPredictions(event, es);
CreateOutputs(event, pred);

const tensorflow::Tensor& pred = GetPredictions(event, es, taus);
CreateOutputs(event, pred, taus);
}

void DeepTauBase::CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred)
void DeepTauBase::CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus)
{
for(const auto& output_desc : outputs) {
auto result_map = output_desc.second.get_value(taus, pred, working_points.at(output_desc.first));
Expand Down
4 changes: 3 additions & 1 deletion RecoTauTag/RecoTau/test/runDeepTauIDsOnMiniAOD.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
updatedTauName = updatedTauName,
toKeep = [ "2017v2", "dR0p32017v2", "newDM2017v2",
"deepTau2017v1",
"DPFTau_2016_v0","DPFTau_2016_v1"])
"DPFTau_2016_v0",
#"DPFTau_2016_v1"
])
tauIdEmbedder.runTauID()

# Output definition
Expand Down