Skip to content

Commit

Permalink
Made DeepTauId and DPFIsolation thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
MRD2F committed Oct 27, 2018
1 parent e1f055d commit 0417afd
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 23 deletions.
7 changes: 3 additions & 4 deletions RecoTauTag/RecoTau/interface/DeepTauBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,17 @@ class DeepTauBase : public edm::stream::EDProducer<> {
virtual void produce(edm::Event& event, const edm::EventSetup& es) override;

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) = 0;
virtual void CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred);
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) = 0;
virtual void CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus);

protected:
edm::EDGetTokenT<TauCollection> taus_token;
edm::Handle<TauCollection> taus;
std::string graphName;
GraphPtr graph;
tensorflow::Session* session;
std::map<std::string, WPMap> working_points;
OutputCollection outputs;

};

} // namespace deep_tau
Expand Down
19 changes: 9 additions & 10 deletions RecoTauTag/RecoTau/plugins/DPFIsolation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,21 @@ class DPFIsolation : public deep_tau::DeepTauBase {
graphVersion = 1;
else
throw cms::Exception("DPFIsolation") << "unknown version of the graph file.";

tensor = tensorflow::Tensor(tensorflow::DT_FLOAT, {1,
static_cast<int>(GetNumberOfParticles(graphVersion)), static_cast<int>(GetNumberOfFeatures(graphVersion))});
}

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) override
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) override
{
edm::Handle<pat::PackedCandidateCollection> pfcands;
event.getByToken(pfcand_token, pfcands);

edm::Handle<reco::VertexCollection> vertices;
event.getByToken(vtx_token, vertices);

tensorflow::Tensor tensor(tensorflow::DT_FLOAT, {1,
static_cast<int>(GetNumberOfParticles(graphVersion)), static_cast<int>(GetNumberOfFeatures(graphVersion))});

tensorflow::Tensor predictions(tensorflow::DT_FLOAT, { static_cast<int>(taus->size()), 1});

float pfCandPt, pfCandPz, pfCandPtRel, pfCandPzRel, pfCandDr, pfCandDEta, pfCandDPhi, pfCandEta, pfCandDz,
Expand Down Expand Up @@ -378,13 +382,8 @@ class DPFIsolation : public deep_tau::DeepTauBase {

private:
edm::EDGetTokenT<pat::PackedCandidateCollection> pfcand_token;
edm::EDGetTokenT<reco::VertexCollection> vtx_token;

edm::Handle<pat::PackedCandidateCollection> pfcands;
edm::Handle<reco::VertexCollection> vertices;

edm::EDGetTokenT<reco::VertexCollection> vtx_token;
unsigned graphVersion;
tensorflow::Tensor tensor;
std::vector<tensorflow::Tensor> outputs;
};

Expand Down
10 changes: 5 additions & 5 deletions RecoTauTag/RecoTau/plugins/DeepTauId.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ class DeepTauId : public deep_tau::DeepTauBase {
}

private:
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es) override
virtual tensorflow::Tensor GetPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) override
{
edm::Handle<pat::ElectronCollection> electrons;
event.getByToken(electrons_token, electrons);
Expand Down Expand Up @@ -306,9 +307,9 @@ class DeepTauId : public deep_tau::DeepTauBase {
void SetInputs(const TauCollection& taus, size_t tau_index, tensorflow::Tensor& inputs,
const ElectronCollection& electrons, const MuonCollection& muons) const
{

static constexpr bool check_all_set = false;
static constexpr float magic_number = -42;
static const TauIdMVAAuxiliaries clusterVariables;
const auto& get = [&](int var_index) -> float& { return inputs.matrix<float>()(tau_index, var_index); };
const TauType& tau = taus.at(tau_index);
auto leadChargedHadrCand = dynamic_cast<const pat::PackedCandidate*>(tau.leadChargedHadrCand().get());
Expand Down Expand Up @@ -604,9 +605,8 @@ class DeepTauId : public deep_tau::DeepTauBase {
const double dR2 = deltaR*deltaR;
const pat::Electron* matched_ele = nullptr;
for(const auto& ele : electrons) {
if(reco::deltaR2(tau.p4(), ele.p4()) < dR2 &&
(!matched_ele || matched_ele->pt() < ele.pt())) {
matched_ele = &ele;
if(reco::deltaR2(tau.p4(), ele.p4()) < dR2 && (!matched_ele || matched_ele->pt() < ele.pt())) {
matched_ele = &ele;
}
}
return matched_ele;
Expand Down
2 changes: 1 addition & 1 deletion RecoTauTag/RecoTau/python/tools/runTauIdMVA.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def runTauID(self):
"Loose": "0.999755",
"Medium": "0.999854",
"Tight": "0.999886",
"VTight": "0.99994",
"VTight": "0.999944",
"VVTight": "0.9999971"
},

Expand Down
8 changes: 5 additions & 3 deletions RecoTauTag/RecoTau/src/DeepTauBase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ DeepTauBase::~DeepTauBase()

void DeepTauBase::produce(edm::Event& event, const edm::EventSetup& es)
{
edm::Handle<TauCollection> taus;
event.getByToken(taus_token, taus);
const tensorflow::Tensor& pred = GetPredictions(event, es);
CreateOutputs(event, pred);

const tensorflow::Tensor& pred = GetPredictions(event, es, taus);
CreateOutputs(event, pred, taus);
}

void DeepTauBase::CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred)
void DeepTauBase::CreateOutputs(edm::Event& event, const tensorflow::Tensor& pred, edm::Handle<TauCollection> taus)
{
for(const auto& output_desc : outputs) {
auto result_map = output_desc.second.get_value(taus, pred, working_points.at(output_desc.first));
Expand Down

0 comments on commit 0417afd

Please sign in to comment.