Skip to content

Commit

Permalink
Extend MVA support to multiple MVA variables
Browse files Browse the repository at this point in the history
  • Loading branch information
makortel committed Jun 16, 2017
1 parent 3e6dc7e commit 264d900
Showing 1 changed file with 50 additions and 30 deletions.
80 changes: 50 additions & 30 deletions Validation/RecoTrack/plugins/TrackingNtuple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ class TrackingNtuple : public edm::one::EDAnalyzer<edm::one::SharedResources> {
const TrackerTopology& tTopo,
const std::set<edm::ProductID>& hitProductIds,
const std::map<edm::ProductID, size_t>& seedToCollIndex,
const MVACollection *mvaColl,
const QualityMaskCollection *qualColl
const std::vector<const MVACollection *>& mvaColls,
const std::vector<const QualityMaskCollection *>& qualColls
);

void fillSimHits(const TrackerGeometry& tracker,
Expand Down Expand Up @@ -452,8 +452,7 @@ class TrackingNtuple : public edm::one::EDAnalyzer<edm::one::SharedResources> {
std::vector<edm::EDGetTokenT<edm::View<reco::Track> > > seedTokens_;
std::vector<edm::EDGetTokenT<std::vector<short> > > seedStopReasonTokens_;
edm::EDGetTokenT<edm::View<reco::Track> > trackToken_;
edm::EDGetTokenT<MVACollection> trackMVAToken_;
edm::EDGetTokenT<QualityMaskCollection> trackQualMaskToken_;
std::vector<std::tuple<edm::EDGetTokenT<MVACollection>, edm::EDGetTokenT<QualityMaskCollection> > > mvaQualityCollectionTokens_;
edm::EDGetTokenT<TrackingParticleCollection> trackingParticleToken_;
edm::EDGetTokenT<TrackingParticleRefVector> trackingParticleRefToken_;
edm::EDGetTokenT<ClusterTPAssociation> clusterTPMapToken_;
Expand Down Expand Up @@ -526,8 +525,8 @@ class TrackingNtuple : public edm::one::EDAnalyzer<edm::one::SharedResources> {
std::vector<float> trk_nChi2 ;
std::vector<float> trk_nChi2_1Dmod;
std::vector<float> trk_ndof ;
std::vector<float> trk_mva;
std::vector<unsigned short> trk_qualityMask;
std::vector<std::vector<float>> trk_mvas;
std::vector<std::vector<unsigned short>> trk_qualityMasks;
std::vector<int> trk_q ;
std::vector<unsigned int> trk_nValid ;
std::vector<unsigned int> trk_nInvalid;
Expand Down Expand Up @@ -836,9 +835,11 @@ TrackingNtuple::TrackingNtuple(const edm::ParameterSet& iConfig):
tracer_.depth(-2); // as in SimTracker/TrackHistory/src/TrackClassifier.cc

if(includeMVA_) {
auto mvaTag = iConfig.getUntrackedParameter<std::string>("trackMVAs");
trackMVAToken_ = consumes<MVACollection>(edm::InputTag(mvaTag, "MVAValues"));
trackQualMaskToken_ = consumes<QualityMaskCollection>(edm::InputTag(mvaTag, "QualityMasks"));
mvaQualityCollectionTokens_ = edm::vector_transform(iConfig.getUntrackedParameter<std::vector<std::string> >("trackMVAs"),
[&](const std::string& tag) {
return std::make_tuple(consumes<MVACollection>(edm::InputTag(tag, "MVAValues")),
consumes<QualityMaskCollection>(edm::InputTag(tag, "QualityMasks")));
});
}

usesResource(TFileService::kSharedResource);
Expand Down Expand Up @@ -885,8 +886,16 @@ TrackingNtuple::TrackingNtuple(const edm::ParameterSet& iConfig):
t->Branch("trk_nChi2_1Dmod", &trk_nChi2_1Dmod);
t->Branch("trk_ndof" , &trk_ndof);
if(includeMVA_) {
t->Branch("trk_mva" , &trk_mva);
t->Branch("trk_qualityMask", &trk_qualityMask);
trk_mvas.resize(mvaQualityCollectionTokens_.size());
trk_qualityMasks.resize(mvaQualityCollectionTokens_.size());
if(!trk_mvas.empty()) {
t->Branch("trk_mva" , &(trk_mvas[0]));
t->Branch("trk_qualityMask", &(trk_qualityMasks[0]));
for(size_t i=1; i<trk_mvas.size(); ++i) {
t->Branch(("trk_mva"+std::to_string(i+1)).c_str(), &(trk_mvas[i]));
t->Branch(("trk_qualityMask"+std::to_string(i+1)).c_str(), &(trk_qualityMasks[i]));
}
}
}
t->Branch("trk_q" , &trk_q);
t->Branch("trk_nValid" , &trk_nValid );
Expand Down Expand Up @@ -1209,8 +1218,12 @@ void TrackingNtuple::clearVariables() {
trk_nChi2 .clear();
trk_nChi2_1Dmod.clear();
trk_ndof .clear();
trk_mva .clear();
trk_qualityMask.clear();
for(auto& mva: trk_mvas) {
mva.clear();
}
for(auto& mask: trk_qualityMasks) {
mask.clear();
}
trk_q .clear();
trk_nValid .clear();
trk_nInvalid .clear();
Expand Down Expand Up @@ -1597,26 +1610,31 @@ void TrackingNtuple::analyze(const edm::Event& iEvent, const edm::EventSetup& iS
for(edm::View<Track>::size_type i=0; i<tracks.size(); ++i) {
trackRefs.push_back(tracks.refAt(i));
}
const MVACollection *mvaColl = nullptr;
const QualityMaskCollection *qualColl = nullptr;
std::vector<const MVACollection *> mvaColls;
std::vector<const QualityMaskCollection *> qualColls;
if(includeMVA_) {
edm::Handle<MVACollection> hmva;
iEvent.getByToken(trackMVAToken_, hmva);
mvaColl = hmva.product();
if(mvaColl->size() != tracks.size())
throw cms::Exception("LogicError") << "Inconsistent track collection size (" << tracks.size() << ") and MVA collection size (" << mvaColl->size() << ")";

edm::Handle<QualityMaskCollection> hqual;
iEvent.getByToken(trackQualMaskToken_, hqual);
qualColl = hqual.product();
if(qualColl->size() != tracks.size())
throw cms::Exception("LogicError") << "Inconsistent track collection size (" << tracks.size() << ") and quality mask collection size (" << qualColl->size() << ")";

for(const auto& tokenTpl: mvaQualityCollectionTokens_) {
iEvent.getByToken(std::get<0>(tokenTpl), hmva);
iEvent.getByToken(std::get<1>(tokenTpl), hqual);

mvaColls.push_back(hmva.product());
qualColls.push_back(hqual.product());
if(mvaColls.back()->size() != tracks.size()) {
throw cms::Exception("Configuration") << "Inconsistency in track collection and MVA sizes. Track collection has " << tracks.size() << " tracks, whereas the MVA " << (mvaColls.size()-1) << " has " << mvaColls.back()->size() << " entries. Double-check your configuration.";
}
if(qualColls.back()->size() != tracks.size()) {
throw cms::Exception("Configuration") << "Inconsistency in track collection and quality mask sizes. Track collection has " << tracks.size() << " tracks, whereas the quality mask " << (qualColls.size()-1) << " has " << qualColls.back()->size() << " entries. Double-check your configuration.";
}
}
}

edm::Handle<reco::VertexCollection> vertices;
iEvent.getByToken(vertexToken_, vertices);

fillTracks(trackRefs, tpCollection, tpKeyToIndex, bs, *vertices, associatorByHits, *theTTRHBuilder, tTopo, hitProductIds, seedCollToOffset, mvaColl, qualColl);
fillTracks(trackRefs, tpCollection, tpKeyToIndex, bs, *vertices, associatorByHits, *theTTRHBuilder, tTopo, hitProductIds, seedCollToOffset, mvaColls, qualColls);

//tracking particles
//sort association maps with simHits
Expand Down Expand Up @@ -2408,8 +2426,8 @@ void TrackingNtuple::fillTracks(const edm::RefToBaseVector<reco::Track>& tracks,
const TrackerTopology& tTopo,
const std::set<edm::ProductID>& hitProductIds,
const std::map<edm::ProductID, size_t>& seedCollToOffset,
const MVACollection *mvaColl,
const QualityMaskCollection *qualColl
const std::vector<const MVACollection *>& mvaColls,
const std::vector<const QualityMaskCollection *>& qualColls
) {
reco::RecoToSimCollection recSimColl = associatorByHits.associateRecoToSim(tracks, tpCollection);
edm::EDConsumerBase::Labels labels;
Expand Down Expand Up @@ -2509,8 +2527,10 @@ void TrackingNtuple::fillTracks(const edm::RefToBaseVector<reco::Track>& tracks,
trk_stopReason.push_back(itTrack->stopReason());
trk_isHP .push_back(itTrack->quality(reco::TrackBase::highPurity));
if(includeMVA_) {
trk_mva .push_back((*mvaColl)[iTrack]);
trk_qualityMask.push_back((*qualColl)[iTrack]);
for(size_t i=0; i<trk_mvas.size(); ++i) {
trk_mvas [i].push_back( (*(mvaColls [i]))[iTrack] );
trk_qualityMasks[i].push_back( (*(qualColls[i]))[iTrack] );
}
}
if(includeSeeds_) {
auto offset = seedCollToOffset.find(itTrack->seedRef().id());
Expand Down Expand Up @@ -2882,7 +2902,7 @@ void TrackingNtuple::fillDescriptions(edm::ConfigurationDescriptions& descriptio
edm::InputTag("muonSeededTrackCandidatesOutIn")
});
desc.addUntracked<edm::InputTag>("tracks", edm::InputTag("generalTracks"));
desc.addUntracked<std::string>("trackMVAs", "generalTracks");
desc.addUntracked<std::vector<std::string> >("trackMVAs", std::vector<std::string>{{"generalTracks"}});
desc.addUntracked<edm::InputTag>("trackingParticles", edm::InputTag("mix", "MergedTrackTruth"));
desc.addUntracked<bool>("trackingParticlesRef", false);
desc.addUntracked<edm::InputTag>("clusterTPMap", edm::InputTag("tpClusterProducer"));
Expand Down

0 comments on commit 264d900

Please sign in to comment.