[RF] Replace multi-range fit logic in RooAbsPdf
Multi-range fits in RooFit are more complicated than they should be.

In principle, all that is required is to change the normalization range
of the PDF to the union of the ranges.

There is a RooAbsPdf interface that suggests this could be done easily
like this:
```C++
pdf.setNormRange("range1,range2");
```

But this didn't work for RooAddPdfs, which is probably why multi-range
fits were implemented as a sum of separate RooNLLVars in the old test
statistics framework. In that approach, however, the PDFs are normalized
separately, and extra terms need to be introduced to correct for that.
This resulted in lots of complicated code, and there are still issues
like root-project#11447, i.e. it still doesn't work for simultaneous fits.
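
For reference, the correction can be written out as follows. The notation is introduced here for illustration and is not part of the original commit: $f$ is the unnormalized PDF, $I_r$ is its integral over range $r$, $N_r$ is the number of events in range $r$, and $I = \sum_r I_r$, $N = \sum_r N_r$. For a non-extended fit with non-overlapping ranges, the sum of separately normalized NLLs is

$$
\sum_r \mathrm{NLL}_r = -\sum_r \sum_{i \in r} \log \frac{f(x_i)}{I_r}
= -\sum_i \log f(x_i) + \sum_r N_r \log I_r ,
$$

whereas the NLL normalized over the union of the ranges is

$$
\mathrm{NLL}_{\mathrm{union}} = -\sum_i \log \frac{f(x_i)}{I}
= -\sum_i \log f(x_i) + N \log I .
$$

The difference between the two is the correction that had to be added:

$$
\mathrm{NLL}_{\mathrm{union}} - \sum_r \mathrm{NLL}_r = -\sum_r N_r \log I_r + N \log I .
$$

This has the same form as the per-range and full-range terms built by the `createMultiRangeNLLCorrectionTerm` helper that this commit removes from `RooAbsPdf.cxx`.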

Now that the previous commits fixed the `RooAddPdf` behavior for
`setNormRange()`, one can actually use comma-separated normalization
ranges for a single PDF in a multi-range fit. With this commit, this is
done in the old RooFit test statistics classes.
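
That a single union-normalized NLL gives the same value as the old sum of separately normalized NLLs plus correction terms can be checked numerically. The following is a minimal, self-contained sketch that does not use RooFit: the toy model `exp(-x)`, the function names, and the assumption of non-overlapping ranges are all chosen for illustration and are not taken from the commit.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

using Range = std::pair<double, double>; // [lo, hi]; ranges are assumed not to overlap

// Toy unnormalized model f(x) = exp(-x) and its analytic integral over a range.
double model(double x) { return std::exp(-x); }
double modelIntegral(const Range &r) { return std::exp(-r.first) - std::exp(-r.second); }

bool inRange(double x, const Range &r) { return x >= r.first && x <= r.second; }

// One NLL, normalized over the union of all ranges.
double nllUnionNorm(const std::vector<double> &data, const std::vector<Range> &ranges)
{
   assert(!ranges.empty());
   double unionIntegral = 0.0;
   for (const auto &r : ranges) unionIntegral += modelIntegral(r);

   double nll = 0.0;
   for (double x : data) {
      for (const auto &r : ranges) {
         if (inRange(x, r)) {
            nll -= std::log(model(x) / unionIntegral);
            break; // non-overlapping ranges: each event is counted once
         }
      }
   }
   return nll;
}

// Sum of per-range NLLs, each normalized over its own range, plus correction terms.
double nllSeparateNormCorrected(const std::vector<double> &data, const std::vector<Range> &ranges)
{
   assert(!ranges.empty());
   double nll = 0.0, nTotal = 0.0, unionIntegral = 0.0;
   for (const auto &r : ranges) {
      const double rangeIntegral = modelIntegral(r);
      double nInRange = 0.0;
      for (double x : data) {
         if (inRange(x, r)) {
            nll -= std::log(model(x) / rangeIntegral);
            nInRange += 1.0;
         }
      }
      nll -= nInRange * std::log(rangeIntegral); // undo the per-range normalization
      nTotal += nInRange;
      unionIntegral += rangeIntegral;
   }
   return nll + nTotal * std::log(unionIntegral); // normalize over the union instead
}
```

Called with, say, the ranges `[0, 1]` and `[2, 4]`, both functions should agree up to floating-point rounding, and events outside every range are ignored by both. The first function needs only one normalization integral per evaluation pass and no bookkeeping of per-range event counts.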

This has several advantages:

1. Fixes issues with multi-range simultaneous fits
2. It's now in harmony with the logic in the new BatchMode
3. Some speedup because there are fewer nodes in the computation graph
4. Less code required
5. Logic is easier to understand

Maybe there are also new bugs now, but they can be fixed later. I am
sure that the commit already fixes more issues than it creates.
guitargeek committed Oct 1, 2022
1 parent 82b21cd commit 9e61bfd
Showing 3 changed files with 33 additions and 120 deletions.
50 changes: 23 additions & 27 deletions roofit/roofitcore/src/RooAbsOptTestStatistic.cxx
@@ -61,6 +61,8 @@ parallelized calculation of test statistics.
#include "RooVectorDataStore.h"
#include "RooBinSamplingPdf.h"

#include "ROOT/StringUtils.hxx"

using namespace std;

ClassImp(RooAbsOptTestStatistic);
@@ -290,16 +292,27 @@ void RooAbsOptTestStatistic::initSlave(RooAbsReal& real, RooAbsData& indata, con
// ******************************************************************

std::unique_ptr<RooArgSet> origObsSet( real.getObservables(indata) );
RooArgSet* dataObsSet = (RooArgSet*) _dataClone->get() ;
if (rangeName && strlen(rangeName)) {
cxcoutI(Fitting) << "RooAbsOptTestStatistic::ctor(" << GetName() << ") constructing test statistic for sub-range named " << rangeName << endl ;

bool observablesKnowRange = false;
if(auto pdfClone = dynamic_cast<RooAbsPdf*>(_funcClone)) {
pdfClone->setNormRange(rangeName);
}

// Adjust FUNC normalization ranges to requested fitRange, store original ranges for RooAddPdf coefficient interpretation
for (const auto arg : *_funcObsSet) {

RooRealVar* realObs = dynamic_cast<RooRealVar*>(arg) ;
if (realObs) {
if (auto realObs = dynamic_cast<RooRealVar*>(arg)) {

auto tokens = ROOT::Split(rangeName, ",");
for(std::string const& token : tokens) {
if(!realObs->hasRange(token.c_str())) {
std::stringstream errMsg;
errMsg << "The observable \"" << realObs->GetName() << "\" doesn't define the requested range \""
<< token << "\". Replacing it with the default range." << std::endl;
coutW(Fitting) << errMsg.str() << std::endl;
}
}

auto transferRangeAndBinning = [&](RooRealVar & toVar, const char* toName, const char* fromName) {
toVar.setRange(toName, realObs->getMin(fromName),realObs->getMax(fromName));
@@ -312,22 +325,11 @@ void RooAbsOptTestStatistic::initSlave(RooAbsReal& real, RooAbsData& indata, con
}
};

observablesKnowRange |= realObs->hasRange(rangeName);

// If no explicit range is given for RooAddPdf coefficients, create explicit named range equivalent to original observables range
if (!(addCoefRangeName && strlen(addCoefRangeName))) {
transferRangeAndBinning(*realObs, Form("NormalizationRangeFor%s",rangeName), nullptr);
}

// Adjust range of function observable to those of given named range
transferRangeAndBinning(*realObs, nullptr, rangeName);

// Adjust range of data observable to those of given named range
RooRealVar* dataObs = (RooRealVar*) dataObsSet->find(realObs->GetName()) ;
transferRangeAndBinning(*dataObs, nullptr, rangeName);

// Keep track of list of fit ranges in string attribute fit range of original p.d.f.
if (!_splitRange) {
// Keep track of list of fit ranges in string attribute fit range of
// original PDF, but only do so if we don't do multi-range
// normalization. In this case, we could not define an equivalent
// single fit range.
if (!_splitRange && strchr(rangeName,',') == 0) {
const std::string fitRangeName = std::string("fit_") + GetName();
const char* origAttrib = real.getStringAttribute("fitrange") ;
std::string newAttr = origAttrib ? origAttrib : "";
@@ -343,9 +345,6 @@ void RooAbsOptTestStatistic::initSlave(RooAbsReal& real, RooAbsData& indata, con
}
}
}

if (!observablesKnowRange)
coutW(Fitting) << "None of the fit observables seem to know the range '" << rangeName << "'. This means that the full range will be used." << std::endl;
}


@@ -366,10 +365,6 @@ void RooAbsOptTestStatistic::initSlave(RooAbsReal& real, RooAbsData& indata, con
cxcoutI(Fitting) << "RooAbsOptTestStatistic::ctor(" << GetName()
<< ") fixing interpretation of coefficients of any RooAddPdf component to range " << addCoefRangeName << endl ;
_funcClone->fixAddCoefRange(addCoefRangeName,false) ;
} else {
cxcoutI(Fitting) << "RooAbsOptTestStatistic::ctor(" << GetName()
<< ") fixing interpretation of coefficients of any RooAddPdf to full domain of observables " << endl ;
_funcClone->fixAddCoefRange(Form("NormalizationRangeFor%s",rangeName),false) ;
}
}

@@ -859,3 +854,4 @@ void RooAbsOptTestStatistic::setUpBinSampling() {
const char* RooAbsOptTestStatistic::cacheUniqueSuffix() const {
return Form("_%lx", _dataClone->uniqueId().value()) ;
}

85 changes: 3 additions & 82 deletions roofit/roofitcore/src/RooAbsPdf.cxx
@@ -938,51 +938,6 @@ RooAbsReal* RooAbsPdf::createNLL(RooAbsData& data, const RooCmdArg& arg1, const
return createNLL(data,l) ;
}

namespace {

std::unique_ptr<RooAbsReal> createMultiRangeNLLCorrectionTerm(
RooAbsPdf const &pdf, RooAbsData const &data, std::string const &baseName, std::string const &rangeNames)
{
double sumEntriesTotal = 0.0;

RooArgList termList;
RooArgList integralList;

for (const auto &currentRangeName : ROOT::Split(rangeNames, ",")) {
const std::string currentName = baseName + "_" + currentRangeName;

auto sumEntriesCurrent = data.sumEntries("1", currentRangeName.c_str());
sumEntriesTotal += sumEntriesCurrent;

RooArgSet depList;
pdf.getObservables(data.get(), depList);
auto pdfIntegralCurrent = pdf.createIntegral(depList, &depList, nullptr, currentRangeName.c_str());

auto term = new RooFormulaVar((currentName + "_correctionTerm").c_str(),
(std::string("-(") + std::to_string(sumEntriesCurrent) + " * log(x[0]))").c_str(),
RooArgList(*pdfIntegralCurrent));

termList.add(*term);
integralList.add(*pdfIntegralCurrent);
}

auto integralFull = new RooAddition((baseName + "_correctionFullIntegralTerm").c_str(),
"integral",
integralList,
true);

auto fullRangeTerm = new RooFormulaVar((baseName + "_foobar").c_str(),
(std::string("(") + std::to_string(sumEntriesTotal) + " * log(x[0]))").c_str(),
RooArgList(*integralFull));

termList.add(*fullRangeTerm);
return std::unique_ptr<RooAbsReal>{
new RooAddition((baseName + "_correction").c_str(), "correction", termList, true)};
}


} // namespace


////////////////////////////////////////////////////////////////////////////////
/// Construct representation of -log(L) of PDF with given dataset. If dataset is unbinned, an unbinned likelihood is constructed. If the dataset
@@ -1129,43 +1084,9 @@ RooAbsReal* RooAbsPdf::createNLL(RooAbsData& data, const RooLinkedList& cmdList)
cfg.integrateOverBinsPrecision = pc.getDouble("IntegrateBins");
cfg.binnedL = false;
cfg.takeGlobalObservablesFromData = takeGlobalObservablesFromData;
if (!rangeName || strchr(rangeName,',')==0) {
// Simple case: default range, or single restricted range
//cout<<"FK: Data test 1: "<<data.sumEntries()<<endl;

cfg.rangeName = rangeName ? rangeName : "";
nll = std::make_unique<RooNLLVar>(baseName.c_str(),"-log(likelihood)",*this,data,projDeps, ext, cfg);
static_cast<RooNLLVar&>(*nll).batchMode(batchMode == RooFit::BatchModeOption::Old);
} else {
// Composite case: multiple ranges
RooArgList nllList ;
auto tokens = ROOT::Split(rangeName, ",");
if (RooHelpers::checkIfRangesOverlap(*this, data, tokens, cfg.splitCutRange)) {
throw std::runtime_error(
std::string("Error in RooAbsPdf::createNLL! The ranges ") + rangeName + " are overlapping!");
}
for (const auto& token : tokens) {
cfg.rangeName = token;
auto nllComp = std::make_unique<RooNLLVar>((baseName + "_" + token).c_str(),"-log(likelihood)",
*this,data,projDeps,ext,cfg);
nllComp->batchMode(pc.getInt("BatchMode"));
nllList.addOwned(std::move(nllComp)) ;
}

if (!ext) {
// Each RooNLLVar was created with the normalization set corresponding to
// the subrange, not the union range like it should be. We have to add an
// extra term to cancel this normalization problem. However, this is
// only necessary for the non-extended case, because adding an extension
// term to the individual NLLs as done here is mathematically equivalent
// to adding the normalization correction terms plus a global extension
// term.
nllList.addOwned(createMultiRangeNLLCorrectionTerm(*this, data, baseName, rangeName));
}

nll = std::make_unique<RooAddition>(baseName.c_str(),"-log(likelihood)",nllList) ;
nll->addOwnedComponents(std::move(nllList));
}
cfg.rangeName = rangeName ? rangeName : "";
nll = std::make_unique<RooNLLVar>(baseName.c_str(),"-log(likelihood)",*this,data,projDeps, ext, cfg);
static_cast<RooNLLVar&>(*nll).batchMode(batchMode == RooFit::BatchModeOption::Old);
RooAbsReal::setEvalErrorLoggingMode(RooAbsReal::PrintErrors) ;

// Include constraints, if any, in likelihood
18 changes: 7 additions & 11 deletions roofit/roofitcore/src/RooAbsTestStatistic.cxx
@@ -49,6 +49,8 @@ combined in the main thread.
#include "RooRealSumPdf.h"
#include "RooAbsCategoryLValue.h"

#include "ROOT/StringUtils.hxx"

#include "TTimeStamp.h"
#include "TClass.h"
#include <string>
@@ -107,16 +109,6 @@ RooAbsTestStatistic::RooAbsTestStatistic(const char *name, const char *title, Ro
{
// Register all parameters as servers
_paramSet.add(*std::unique_ptr<RooArgSet>{real.getParameters(&data)});

if (cfg.rangeName.find(',') != std::string::npos) {
auto errorMsg = std::string("Ranges ") + cfg.rangeName
+ " were passed to the RooAbsTestStatistic with name \"" + name + "\", "
+ "but it doesn't support multiple comma-separated fit ranges!\n" +
+ "Instead, one should combine multiple RooAbsTestStatistic objects "
+ "(see RooAbsPdf::createNLL for an example with RooNLLVar).";
coutE(InputArguments) << errorMsg << std::endl;
throw std::invalid_argument(errorMsg);
}
}


@@ -535,7 +527,11 @@ void RooAbsTestStatistic::initSimMode(RooSimultaneous* simpdf, RooAbsData* data,
cfg.integrateOverBinsPrecision = thisAsRooAbsOptTestStatistic->_integrateBinsPrecision;
}
if (_splitRange && !rangeName.empty()) {
cfg.rangeName = rangeName + "_" + catName;
auto tokens = ROOT::Split(rangeName, ",");
for(std::string const& token : tokens) {
cfg.rangeName += token + "_" + catName + ",";
}
cfg.rangeName.pop_back(); // to remove the last comma
cfg.nCPU = _nCPU*(_mpinterl?-1:1);
} else {
cfg.rangeName = rangeName;
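
The per-category range naming in the `_splitRange` branch above can be sketched outside of ROOT as follows. This is only an illustration: `splitNames` is a hypothetical stand-in for `ROOT::Split` (here assumed to skip empty tokens), and `perCategoryRangeName` is a name invented for this example, not a function in RooFit.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ROOT::Split: split on one delimiter, skipping empty tokens.
std::vector<std::string> splitNames(const std::string &names, char delim = ',')
{
   std::vector<std::string> tokens;
   std::stringstream stream(names);
   std::string token;
   while (std::getline(stream, token, delim)) {
      if (!token.empty()) tokens.push_back(token);
   }
   return tokens;
}

// Turn "r1,r2" and category "cat" into "r1_cat,r2_cat".
std::string perCategoryRangeName(const std::string &rangeName, const std::string &catName)
{
   std::string result;
   for (const auto &token : splitNames(rangeName)) {
      if (!result.empty()) result += ','; // separator goes before every token but the first
      result += token + "_" + catName;
   }
   return result;
}
```

For example, `perCategoryRangeName("low,high", "catA")` yields `"low_catA,high_catA"`. Writing the separator before each token instead of trimming a trailing comma is a design choice of this sketch: it needs no `pop_back()` and therefore also behaves for an empty input, which the code in the diff guards against separately with `!rangeName.empty()`.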