Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GBRForest implementation without TMVA #24432

Merged
merged 19 commits into from
Oct 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CommonTools/MVAUtils/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<use name="CondFormats/EgammaObjects"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/Utilities"/>
<use name="CondFormats/DataRecord"/>
<use name="roottmva"/>
<use name="tinyxml2"/>

<export>
<lib name="1"/>
</export>
27 changes: 27 additions & 0 deletions CommonTools/MVAUtils/interface/GBRForestTools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef CommonTools_MVAUtils_GBRForestTools_h
#define CommonTools_MVAUtils_GBRForestTools_h

//--------------------------------------------------------------------------------------------------
//
// GBRForestTools
//
// Utility to parse an XML weights file specifying an ensemble of decision trees into a GBRForest.
//
// Author: Jonas Rembser
//--------------------------------------------------------------------------------------------------


#include "CondFormats/EgammaObjects/interface/GBRForest.h"
#include "FWCore/ParameterSet/interface/FileInPath.h"

#include <memory>

// Create a GBRForest from an XML weight file.
// The weight file is given either as a plain path string or as an edm::FileInPath.
// The forest is returned through a std::unique_ptr, so the caller owns the result.
std::unique_ptr<const GBRForest> createGBRForest(const std::string &weightsFile);
std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath &weightsFile);

// Overloaded versions which take a string vector by reference to store the variable names in.
// NOTE(review): whether varNames is cleared first or appended to is not visible from this
// header -- confirm against the implementation.
std::unique_ptr<const GBRForest> createGBRForest(const std::string &weightsFile, std::vector<std::string> &varNames);
std::unique_ptr<const GBRForest> createGBRForest(const edm::FileInPath &weightsFile, std::vector<std::string> &varNames);

#endif
50 changes: 50 additions & 0 deletions CommonTools/MVAUtils/interface/TMVAEvaluator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef CommonTools_MVAUtils_TMVAEvaluator_h
#define CommonTools_MVAUtils_TMVAEvaluator_h

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "CondFormats/EgammaObjects/interface/GBRForest.h"
#include "FWCore/Framework/interface/EventSetup.h"
#include "FWCore/Utilities/interface/thread_safety_macros.h"
#include "TMVA/IMethod.h"
#include "TMVA/Reader.h"

// Evaluator for a multivariate discriminator. The members show two possible back-ends:
// a TMVA::Reader (mReader) or a GBRForest (mGBRForest).
// NOTE(review): only the declaration is visible here; the behavioral notes below are
// inferred from the signatures and should be confirmed against the implementation file.
class TMVAEvaluator {

public:
TMVAEvaluator();

// Set up the evaluator from a weight file.
//   options, method, weightFile : presumably forwarded to TMVA -- confirm.
//   variables, spectators       : names of the input / spectator variables.
//   useGBRForest, useAdaBoost   : optional switches, both default to false.
void initialize(const std::string& options, const std::string& method, const std::string& weightFile,
const std::vector<std::string>& variables, const std::vector<std::string>& spectators,
bool useGBRForest = false, bool useAdaBoost = false);

// Set up the evaluator from an already existing GBRForest passed by pointer.
// NOTE(review): ownership of gbrForest is not visible from this declaration -- confirm.
void initializeGBRForest(const GBRForest* gbrForest, const std::vector<std::string>& variables,
const std::vector<std::string>& spectators, bool useAdaBoost = false);

// Set up the evaluator from the EventSetup using the given label.
// NOTE(review): presumably the label selects a GBRForest record -- confirm.
void initializeGBRForest(const edm::EventSetup& iSetup, const std::string& label,
const std::vector<std::string>& variables, const std::vector<std::string>& spectators,
bool useAdaBoost = false);

// Evaluation entry points. Each takes a string -> float map of inputs and returns a float.
// NOTE(review): the map keys are presumably the variable names given at initialization.
float evaluateTMVA(const std::map<std::string, float>& inputs, bool useSpectators) const;
float evaluateGBRForest(const std::map<std::string, float>& inputs) const;
float evaluate(const std::map<std::string, float>& inputs, bool useSpectators = false) const;

private:
bool mIsInitialized;
bool mUsingGBRForest;
bool mUseAdaBoost;

std::string mMethod;
// m_mutex is the guard named by the CMS_THREAD_GUARD annotations on mReader,
// mVariables and mSpectators below. It is mutable, as are the two maps.
mutable std::mutex m_mutex;
CMS_THREAD_GUARD(m_mutex) std::unique_ptr<TMVA::Reader> mReader;
std::shared_ptr<const GBRForest> mGBRForest;

// Per-name (index, value) pairs for the input and spectator variables.
// NOTE(review): meaning of the size_t/float pair inferred from the types -- confirm.
CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mVariables;
CMS_THREAD_GUARD(m_mutex) mutable std::map<std::string, std::pair<size_t, float>> mSpectators;
};

#endif // CommonTools_MVAUtils_TMVAEvaluator_h
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,21 @@
* =====================================================================================
*/

#ifndef TMVAZIPREADER_7RXIGO70
#define TMVAZIPREADER_7RXIGO70
#ifndef CommonTools_MVAUtils_TMVAZipReader_h
#define CommonTools_MVAUtils_TMVAZipReader_h

#include "TMVA/Reader.h"
#include "TMVA/IMethod.h"
#include "TMVA/Reader.h"
#include <string>

namespace reco {
namespace details {

bool hasEnding(std::string const &fullString, std::string const &ending);
namespace reco::details {

bool hasEnding(std::string const& fullString, std::string const& ending);
char* readGzipFile(const std::string& weightFile);

TMVA::IMethod* loadTMVAWeights(TMVA::Reader* reader, const std::string& method,
const std::string& weightFile, bool verbose=false);
TMVA::IMethod* loadTMVAWeights(
TMVA::Reader* reader, const std::string& method, const std::string& weightFile, bool verbose = false);

}

}}
#endif /* end of include guard: TMVAZIPREADER_7RXIGO70 */
#endif
8 changes: 8 additions & 0 deletions CommonTools/MVAUtils/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<use name="FWCore/Framework"/>
<use name="FWCore/PluginManager"/>
<use name="FWCore/ParameterSet"/>
<use name="CondCore/DBOutputService"/>
<use name="CondFormats/EgammaObjects"/>
<use name="CommonTools/MVAUtils"/>
<use name="boost"/>
<flags EDM_PLUGIN="1"/>
84 changes: 84 additions & 0 deletions CommonTools/MVAUtils/plugins/GBRForestWriter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#include "CommonTools/MVAUtils/plugins/GBRForestWriter.h"

#include "FWCore/Utilities/interface/Exception.h"

#include "CondCore/DBOutputService/interface/PoolDBOutputService.h"
#include "FWCore/ServiceRegistry/interface/Service.h"

#include "CommonTools/MVAUtils/interface/GBRForestTools.h"

#include <TFile.h>

// Build the writer from its configuration: remember the module label and create one
// jobEntryType per entry of the "jobs" VPSet. The job objects are heap-allocated and
// stored as raw pointers in jobs_; the destructor releases them.
GBRForestWriter::GBRForestWriter(const edm::ParameterSet& cfg)
    : moduleLabel_(cfg.getParameter<std::string>("@module_label"))
{
    edm::VParameterSet jobConfigs = cfg.getParameter<edm::VParameterSet>("jobs");
    for (const edm::ParameterSet& jobConfig : jobConfigs) {
        jobs_.push_back(new jobEntryType(jobConfig));
    }
}

// Release every job entry that the constructor allocated into jobs_.
GBRForestWriter::~GBRForestWriter()
{
    for (jobEntryType* ownedJob : jobs_) {
        delete ownedJob;
    }
}

// Load every configured GBRForest and write it to the configured output.
//
// For each job, the forest of each category is read either from an XML weight file
// (through createGBRForest) or from a ROOT file (looked up by gbrForestName_), and is
// stored under the key gbrForestName_. The collected forests are then written either to
// a ROOT file opened in "RECREATE" mode or through the PoolDBOutputService, depending on
// the job's outputFileType_. When a job holds more than one forest, the database record
// name is suffixed with "_<forest name>".
//
// Throws cms::Exception if a forest cannot be loaded or if the PoolDBOutputService is
// not available.
//
// All forests and TFile handles are held by RAII owners, so nothing leaks when one of
// the exceptions is thrown or when two categories share the same forest name.
//
// NOTE(review): nothing in this function guards against repeated invocation, so the
// output is rewritten on every call; the hasRun_ member declared in the class is not
// consulted here -- confirm whether that is intended.
void GBRForestWriter::analyze(const edm::Event&, const edm::EventSetup&)
{
    for (const jobEntryType* job : jobs_) {
        // key = name; the map owns the forests for the duration of this job
        std::map<std::string, std::unique_ptr<const GBRForest>> gbrForests;

        for (const categoryEntryType* category : job->categories_) {
            std::unique_ptr<const GBRForest> gbrForest;
            if (category->inputFileType_ == categoryEntryType::kXML) {
                gbrForest = createGBRForest(category->inputFileName_);
            } else if (category->inputFileType_ == categoryEntryType::kGBRForest) {
                // The TFile is closed automatically at the end of this scope.
                TFile inputFile(category->inputFileName_.data());
                // GetObject performs a class-checked lookup, which replaces the former
                // unchecked C-style cast of the TObject* returned by Get(). It leaves
                // the pointer null if the key is missing or holds another class.
                GBRForest* loadedForest = nullptr;
                inputFile.GetObject(category->gbrForestName_.data(), loadedForest);
                gbrForest.reset(loadedForest);
            }
            if (!gbrForest)
                throw cms::Exception("GBRForestWriter")
                    << " Failed to load GBRForest = " << category->gbrForestName_.data()
                    << " from file = " << category->inputFileName_ << " !!\n";
            // A repeated name replaces (and frees) the previously loaded forest.
            gbrForests[category->gbrForestName_] = std::move(gbrForest);
        }

        if (job->outputFileType_ == jobEntryType::kGBRForest) {
            // Written and closed when outputFile goes out of scope.
            TFile outputFile(job->outputFileName_.data(), "RECREATE");
            for (const auto& [forestName, forest] : gbrForests) {
                outputFile.WriteObject(forest.get(), forestName.data());
            }
        } else if (job->outputFileType_ == jobEntryType::kSQLLite) {
            edm::Service<cond::service::PoolDBOutputService> dbService;
            if (!dbService.isAvailable())
                throw cms::Exception("GBRForestWriter") << " Failed to access PoolDBOutputService !!\n";

            for (const auto& [forestName, forest] : gbrForests) {
                std::string outputRecord = job->outputRecord_;
                if (gbrForests.size() > 1)
                    outputRecord.append("_").append(forestName);
                dbService->writeOne(forest.get(), dbService->beginOfTime(), outputRecord);
            }
        }
        // The forests are deleted here, when gbrForests is destroyed.
    }
}

#include "FWCore/Framework/interface/MakerMacros.h"

DEFINE_FWK_MODULE(GBRForestWriter);
117 changes: 117 additions & 0 deletions CommonTools/MVAUtils/plugins/GBRForestWriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#ifndef CommonTools_MVAUtils_GBRForestWriter_h
#define CommonTools_MVAUtils_GBRForestWriter_h

/** \class GBRForestWriter
*
* Read GBRForest objects from XML weight files or ROOT files
* and store them in a ROOT file or an SQLite output file
*
* \authors Christian Veelken, LLR
*
*/

#include "FWCore/Framework/interface/EDAnalyzer.h"
#include "FWCore/Framework/interface/Event.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"

#include <string>
#include <vector>

// Analyzer that loads GBRForest objects as described by its configuration and writes
// them out again. The configuration holds a list of "jobs"; each job has one or more
// input categories (XML weight file or ROOT file) and one output target (ROOT file or
// database record).
class GBRForestWriter : public edm::EDAnalyzer {
public:
    explicit GBRForestWriter(const edm::ParameterSet&);
    ~GBRForestWriter() override;

private:
    void analyze(const edm::Event&, const edm::EventSetup&) override;

    std::string moduleLabel_;

    // Initialized in-class so that it never holds an indeterminate value.
    bool hasRun_ = false;

    typedef std::vector<std::string> vstring;

    // Description of one input forest.
    struct categoryEntryType {
        // Reads the input description from cfg.
        //   inputFileName : edm::FileInPath (resolved to its full path) or plain string.
        //   inputFileType : "XML" or "GBRForest".
        //   For XML input : inputVariables, spectatorVariables and methodName are
        //                   required; gbrForestName is optional and falls back to
        //                   methodName.
        //   Otherwise     : gbrForestName is required.
        // Throws cms::Exception if inputFileName is missing or inputFileType is invalid.
        explicit categoryEntryType(const edm::ParameterSet& cfg)
        {
            if (cfg.existsAs<edm::FileInPath>("inputFileName")) {
                edm::FileInPath inputFileName_fip = cfg.getParameter<edm::FileInPath>("inputFileName");
                inputFileName_ = inputFileName_fip.fullPath();
            } else if (cfg.existsAs<std::string>("inputFileName")) {
                inputFileName_ = cfg.getParameter<std::string>("inputFileName");
            } else
                throw cms::Exception("GBRForestWriter") << " Undefined Configuration Parameter 'inputFileName' !!\n";
            std::string inputFileType_string = cfg.getParameter<std::string>("inputFileType");
            if (inputFileType_string == "XML")
                inputFileType_ = kXML;
            else if (inputFileType_string == "GBRForest")
                inputFileType_ = kGBRForest;
            else
                throw cms::Exception("GBRForestWriter")
                    << " Invalid Configuration Parameter 'inputFileType' = " << inputFileType_string << " !!\n";
            if (inputFileType_ == kXML) {
                inputVariables_ = cfg.getParameter<vstring>("inputVariables");
                spectatorVariables_ = cfg.getParameter<vstring>("spectatorVariables");
                methodName_ = cfg.getParameter<std::string>("methodName");
                gbrForestName_
                    = (cfg.existsAs<std::string>("gbrForestName") ? cfg.getParameter<std::string>("gbrForestName")
                                                                  : methodName_);
            } else {
                gbrForestName_ = cfg.getParameter<std::string>("gbrForestName");
            }
        }
        std::string inputFileName_;
        enum { kXML, kGBRForest };
        int inputFileType_;
        vstring inputVariables_;
        vstring spectatorVariables_;
        std::string gbrForestName_;
        std::string methodName_;
    };

    // Description of one write job: its input categories and its output target.
    struct jobEntryType {
        // Reads the job description from cfg.
        //   categories     : optional VPSet with one entry per input forest; if absent,
        //                    cfg itself is interpreted as a single category.
        //   outputFileType : "GBRForest" (requires outputFileName) or
        //                    "SQLLite" (requires outputRecord).
        // Throws cms::Exception if outputFileType is invalid.
        explicit jobEntryType(const edm::ParameterSet& cfg)
        {
            if (cfg.exists("categories")) {
                edm::VParameterSet cfgCategories = cfg.getParameter<edm::VParameterSet>("categories");
                for (const edm::ParameterSet& cfgCategory : cfgCategories) {
                    categories_.push_back(new categoryEntryType(cfgCategory));
                }
            } else {
                categories_.push_back(new categoryEntryType(cfg));
            }
            std::string outputFileType_string = cfg.getParameter<std::string>("outputFileType");
            if (outputFileType_string == "GBRForest")
                outputFileType_ = kGBRForest;
            else if (outputFileType_string == "SQLLite")
                outputFileType_ = kSQLLite;
            else
                throw cms::Exception("GBRForestWriter")
                    << " Invalid Configuration Parameter 'outputFileType' = " << outputFileType_string << " !!\n";
            if (outputFileType_ == kGBRForest) {
                outputFileName_ = cfg.getParameter<std::string>("outputFileName");
            }
            if (outputFileType_ == kSQLLite) {
                outputRecord_ = cfg.getParameter<std::string>("outputRecord");
            }
        }
        // The entries in categories_ are owned by this object.
        ~jobEntryType()
        {
            for (categoryEntryType* category : categories_) {
                delete category;
            }
        }
        // Copying would make two objects delete the same categories, so it is forbidden.
        jobEntryType(const jobEntryType&) = delete;
        jobEntryType& operator=(const jobEntryType&) = delete;

        std::vector<categoryEntryType*> categories_;
        enum { kGBRForest, kSQLLite };
        int outputFileType_;
        std::string outputFileName_;
        std::string outputRecord_;
    };

    // Owned job entries; allocated in the constructor and released in the destructor.
    std::vector<jobEntryType*> jobs_;
};

#endif
Loading