From 76938a5410a4ce07620828bee725b26189cbc8dd Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 10 Mar 2022 10:37:06 -0600 Subject: [PATCH] Added code to handle ConditionalTasks --- .../interface/ProductResolverIndexHelper.h | 4 + .../src/ProductResolverIndexHelper.cc | 26 ++ FWCore/Framework/interface/StreamSchedule.h | 15 ++ FWCore/Framework/src/StreamSchedule.cc | 224 +++++++++++++++++- FWCore/Framework/test/BuildFile.xml | 8 + .../test/test_conditionaltasks_cfg.py | 82 +++++++ 6 files changed, 354 insertions(+), 5 deletions(-) create mode 100644 FWCore/Framework/test/test_conditionaltasks_cfg.py diff --git a/DataFormats/Provenance/interface/ProductResolverIndexHelper.h b/DataFormats/Provenance/interface/ProductResolverIndexHelper.h index 73d20841bfa00..1b957e6bc85d9 100644 --- a/DataFormats/Provenance/interface/ProductResolverIndexHelper.h +++ b/DataFormats/Provenance/interface/ProductResolverIndexHelper.h @@ -84,6 +84,10 @@ namespace edm { // If the TypeID for the wrapped type is already available, // it is faster to call getContainedTypeFromWrapper directly. TypeID getContainedType(TypeID const& typeID); + + bool typeIsViewCompatible(TypeID const& requestedViewType, + TypeID const& wrappedtypeID, + std::string const& className); } // namespace productholderindexhelper class ProductResolverIndexHelper { diff --git a/DataFormats/Provenance/src/ProductResolverIndexHelper.cc b/DataFormats/Provenance/src/ProductResolverIndexHelper.cc index d47946fe1f85c..82221e5fa8a6b 100644 --- a/DataFormats/Provenance/src/ProductResolverIndexHelper.cc +++ b/DataFormats/Provenance/src/ProductResolverIndexHelper.cc @@ -70,6 +70,32 @@ namespace edm { TypeID const wrappedTypeID = TypeID(wrappedType.typeInfo()); return getContainedTypeFromWrapper(wrappedTypeID, className); } + + bool typeIsViewCompatible(TypeID const& requestedViewType, + TypeID const& wrappedtypeID, + std::string const& className) { + auto elementType = getContainedTypeFromWrapper(wrappedtypeID, className); + if (elementType == TypeID(typeid(void)) or elementType == TypeID()) { + //the wrapped type is not a container + return false; + } + if (elementType == requestedViewType) { + return true; + } + //need to check for inheritance match + std::vector missingDictionaries; + std::vector baseTypes; + if (!public_base_classes(missingDictionaries, elementType, baseTypes)) { + return false; + } + for (auto const& base : baseTypes) { + if (TypeID(base.typeInfo()) == requestedViewType) { + return true; + } + } + return false; + } + } // namespace productholderindexhelper ProductResolverIndexHelper::ProductResolverIndexHelper() diff --git a/FWCore/Framework/interface/StreamSchedule.h b/FWCore/Framework/interface/StreamSchedule.h index 640c6c0431313..4dedcaccc894d 100644 --- a/FWCore/Framework/interface/StreamSchedule.h +++ b/FWCore/Framework/interface/StreamSchedule.h @@ -288,6 +288,21 @@ namespace edm { void reportSkipped(EventPrincipal const& ep) const; + struct AliasInfo { + std::string friendlyClassName; + std::string instanceLabel; + std::string originalInstanceLabel; + std::string originalModuleLabel; + }; + std::vector tryToPlaceConditionalModules( + Worker*, + std::unordered_set& conditionalModules, + std::multimap const& conditionalModuleBranches, + std::multimap const& aliasMap, + ParameterSet& proc_pset, + ProductRegistry& preg, + PreallocationConfiguration const* prealloc, + std::shared_ptr processConfiguration); void fillWorkers(ParameterSet& proc_pset, ProductRegistry& preg, PreallocationConfiguration const* prealloc, diff --git a/FWCore/Framework/src/StreamSchedule.cc b/FWCore/Framework/src/StreamSchedule.cc index 342bdff7f242e..25eb198660517 100644 --- a/FWCore/Framework/src/StreamSchedule.cc +++ b/FWCore/Framework/src/StreamSchedule.cc @@ -3,6 +3,7 @@ #include "DataFormats/Provenance/interface/BranchIDListHelper.h" #include "DataFormats/Provenance/interface/ProcessConfiguration.h" #include "DataFormats/Provenance/interface/ProductRegistry.h" +#include "DataFormats/Provenance/interface/ProductResolverIndexHelper.h" #include "FWCore/Framework/src/OutputModuleDescription.h" #include "FWCore/Framework/interface/TriggerNamesService.h" #include "FWCore/Framework/src/TriggerReport.h" @@ -21,6 +22,7 @@ #include "FWCore/ParameterSet/interface/ParameterSetDescription.h" #include "FWCore/ParameterSet/interface/Registry.h" #include "FWCore/ServiceRegistry/interface/PathContext.h" +#include "FWCore/Reflection/interface/DictionaryTools.h" #include "FWCore/Utilities/interface/Algorithms.h" #include "FWCore/Utilities/interface/ConvertException.h" #include "FWCore/Utilities/interface/ExceptionCollector.h" @@ -36,6 +38,7 @@ #include #include #include +#include namespace edm { @@ -378,6 +381,154 @@ namespace edm { } } + static Worker* getWorker(std::string const& moduleLabel, + ParameterSet& proc_pset, + WorkerManager& workerManager, + ProductRegistry& preg, + PreallocationConfiguration const* prealloc, + std::shared_ptr processConfiguration) { + bool isTracked; + ParameterSet* modpset = proc_pset.getPSetForUpdate(moduleLabel, isTracked); + if (modpset == nullptr) { + return nullptr; + } + assert(isTracked); + + return workerManager.getWorker(*modpset, preg, prealloc, processConfiguration, moduleLabel); + } + + std::vector StreamSchedule::tryToPlaceConditionalModules( + Worker* worker, + std::unordered_set& conditionalModules, + std::multimap const& conditionalModuleBranches, + std::multimap const& aliasMap, + ParameterSet& proc_pset, + ProductRegistry& preg, + PreallocationConfiguration const* prealloc, + std::shared_ptr processConfiguration) { + std::vector returnValue; + auto const& consumesInfo = worker->consumesInfo(); + auto moduleLabel = worker->description()->moduleLabel(); + using namespace productholderindexhelper; + for (auto const& ci : consumesInfo) { + if (not ci.skipCurrentProcess() and + (ci.process().empty() or ci.process() == processConfiguration->processName())) { + auto productModuleLabel = ci.label(); + if (productModuleLabel.empty()) { + for (auto const& branch : conditionalModuleBranches) { + if (ci.kindOfType() == edm::PRODUCT_TYPE) { + if (branch.second->unwrappedTypeID() != ci.type()) { + continue; + } + } else { + if (not typeIsViewCompatible( + ci.type(), TypeID(branch.second->wrappedType().typeInfo()), branch.second->className())) { + continue; + } + } + + auto condWorker = + getWorker(productModuleLabel, proc_pset, workerManager_, preg, prealloc, processConfiguration); + assert(condWorker); + + conditionalModules.erase(productModuleLabel); + + auto dependents = tryToPlaceConditionalModules(condWorker, + conditionalModules, + conditionalModuleBranches, + aliasMap, + proc_pset, + preg, + prealloc, + processConfiguration); + returnValue.insert(returnValue.end(), dependents.begin(), dependents.end()); + returnValue.push_back(condWorker); + } + } else { + //just a regular consumes + bool productFromConditionalModule = false; + auto itFound = conditionalModules.find(productModuleLabel); + if (itFound == conditionalModules.end()) { + //Check to see if this was an alias + auto findAlias = aliasMap.equal_range(productModuleLabel); + if (findAlias.first != findAlias.second) { + for (auto it = findAlias.first; it != findAlias.second; ++it) { + //this was previously filtered so only the conditional modules remain + productModuleLabel = it->second.originalModuleLabel; + if (it->second.friendlyClassName == "*" or + (ci.type().friendlyClassName() == it->second.friendlyClassName)) { + if (it->second.instanceLabel == "*" or ci.instance() == it->second.instanceLabel) { + productFromConditionalModule = true; + //need to check the rest of the data product info + break; + } + } else if (ci.kindOfType() == ELEMENT_TYPE) { + //consume is a View so need to do more intrusive search + if (it->second.instanceLabel == "*" or ci.instance() == it->second.instanceLabel) { + //find matching branches in module + auto branches = conditionalModuleBranches.equal_range(productModuleLabel); + for (auto itBranch = branches.first; itBranch != branches.second; ++it) { + if (it->second.originalInstanceLabel == "*" or + itBranch->second->productInstanceName() == it->second.originalInstanceLabel) { + if (typeIsViewCompatible(ci.type(), + TypeID(itBranch->second->wrappedType().typeInfo()), + itBranch->second->className())) { + productFromConditionalModule = true; + break; + } + } + } + } + } + } + } + itFound = conditionalModules.find(productModuleLabel); + } else { + //need to check the rest of the data product info + auto findBranches = conditionalModuleBranches.equal_range(productModuleLabel); + for (auto itBranch = findBranches.first; itBranch != findBranches.second; ++itBranch) { + if (itBranch->second->productInstanceName() == ci.instance()) { + if (ci.kindOfType() == PRODUCT_TYPE) { + if (ci.type() == itBranch->second->unwrappedTypeID()) { + productFromConditionalModule = true; + break; + } + } else { + //this is a view + if (typeIsViewCompatible(ci.type(), + TypeID(itBranch->second->wrappedType().typeInfo()), + itBranch->second->className())) { + productFromConditionalModule = true; + break; + } + } + } + } + } + if (productFromConditionalModule) { + auto condWorker = + getWorker(productModuleLabel, proc_pset, workerManager_, preg, prealloc, processConfiguration); + assert(condWorker); + + conditionalModules.erase(itFound); + + auto dependents = tryToPlaceConditionalModules(condWorker, + conditionalModules, + conditionalModuleBranches, + aliasMap, + proc_pset, + preg, + prealloc, + processConfiguration); + returnValue.insert(returnValue.end(), dependents.begin(), dependents.end()); + returnValue.push_back(condWorker); + } + } + } + } + return returnValue; + } + void StreamSchedule::fillWorkers(ParameterSet& proc_pset, ProductRegistry& preg, PreallocationConfiguration const* prealloc, @@ -389,6 +540,61 @@ namespace edm { vstring modnames = proc_pset.getParameter(pathName); PathWorkers tmpworkers; + //Pull out ConditionalTask modules + auto itCondBegin = std::find(modnames.begin(), modnames.end(), "#"); + + std::unordered_set conditionalmods; + //need to capture + std::multimap aliasMap; + std::multimap conditionalModsBranches; + if (itCondBegin != modnames.end()) { + //the last entry should be ignored since it is required to be "@" + conditionalmods = std::unordered_set( + std::make_move_iterator(itCondBegin + 1), std::make_move_iterator(modnames.begin() + modnames.size() - 1)); + + for (auto const& cond : conditionalmods) { + //force the creation of the conditional modules so alias check can work + (void)getWorker(cond, proc_pset, workerManager_, preg, prealloc, processConfiguration); + } + //find aliases + { + auto aliases = proc_pset.getParameter>("@all_aliases"); + std::string const star("*"); + for (auto const& alias : aliases) { + auto info = proc_pset.getParameter(alias); + auto aliasedToModuleLabels = info.getParameterNames(); + for (auto const& mod : aliasedToModuleLabels) { + if (not mod.empty() and mod[0] != '@' and conditionalmods.find(mod) != conditionalmods.end()) { + auto aliasPSet = proc_pset.getParameter(mod); + std::string type = star; + std::string instance = star; + std::string originalInstance = star; + if (aliasPSet.exists("type")) { + type = aliasPSet.getParameter("type"); + } + if (aliasPSet.exists("toProductInstance")) { + instance = aliasPSet.getParameter("toProductInstance"); + } + if (aliasPSet.exists("fromProductInstance")) { + originalInstance = aliasPSet.getParameter("fromProductInstance"); + } + + aliasMap.emplace(alias, AliasInfo{type, instance, originalInstance, mod}); + } + } + } + } + { + //find branches created by the conditional modules + for (auto const& prod : preg.productList()) { + if (conditionalmods.find(prod.first.moduleLabel()) != conditionalmods.end()) { + conditionalModsBranches.emplace(prod.first.moduleLabel(), &prod.second); + } + } + } + } + modnames.erase(itCondBegin, modnames.end()); + unsigned int placeInPath = 0; for (auto const& name : modnames) { //Modules except EDFilters are set to run concurrently by default @@ -409,9 +615,8 @@ namespace edm { moduleLabel.erase(0, 1); } - bool isTracked; - ParameterSet* modpset = proc_pset.getPSetForUpdate(moduleLabel, isTracked); - if (modpset == nullptr) { + Worker* worker = getWorker(moduleLabel, proc_pset, workerManager_, preg, prealloc, processConfiguration); + if (worker == nullptr) { std::string pathType("endpath"); if (!search_all(endPathNames, pathName)) { pathType = std::string("path"); @@ -420,9 +625,7 @@ namespace edm { << "The unknown module label \"" << moduleLabel << "\" appears in " << pathType << " \"" << pathName << "\"\n please check spelling or remove that label from the path."; } - assert(isTracked); - Worker* worker = workerManager_.getWorker(*modpset, preg, prealloc, processConfiguration, moduleLabel); if (ignoreFilters && filterAction != WorkerInPath::Ignore && worker->moduleType() == Worker::kFilter) { // We have a filter on an end path, and the filter is not explicitly ignored. // See if the filter is allowed. @@ -442,6 +645,17 @@ namespace edm { if (runConcurrently && worker->moduleType() == Worker::kFilter and filterAction != WorkerInPath::Ignore) { runConcurrently = false; } + + //TODO: call consumesInfo and see if need any modules from conditionalmods + // call module's typeLabelList function to see what it produces + // consume many has blank module label so need to check type -> what about Views? + auto condModules = tryToPlaceConditionalModules( + worker, conditionalmods, conditionalModsBranches, aliasMap, proc_pset, preg, prealloc, processConfiguration); + for (auto condMod : condModules) { + tmpworkers.emplace_back(condMod, WorkerInPath::Ignore, placeInPath, true); + ++placeInPath; + } + tmpworkers.emplace_back(worker, filterAction, placeInPath, runConcurrently); ++placeInPath; } diff --git a/FWCore/Framework/test/BuildFile.xml b/FWCore/Framework/test/BuildFile.xml index fbdfec87b1812..a14dbe7358c35 100644 --- a/FWCore/Framework/test/BuildFile.xml +++ b/FWCore/Framework/test/BuildFile.xml @@ -404,4 +404,12 @@ + + + + + + + + diff --git a/FWCore/Framework/test/test_conditionaltasks_cfg.py b/FWCore/Framework/test/test_conditionaltasks_cfg.py new file mode 100644 index 0000000000000..a1e92eb31e05f --- /dev/null +++ b/FWCore/Framework/test/test_conditionaltasks_cfg.py @@ -0,0 +1,82 @@ +import FWCore.ParameterSet.Config as cms + +import argparse +import sys + +parser = argparse.ArgumentParser(prog=sys.argv[0], description='Test ConditionalTasks.') + +parser.add_argument("--filterSucceeds", help="Have filter succeed", action="store_true") +parser.add_argument("--reverseDependencies", help="Switch the order of dependencies", action="store_true") +parser.add_argument("--testAlias", help="Get data from an alias", action="store_true") +parser.add_argument("--testView", help="Get data via a view", action="store_true") +parser.add_argument("--aliasWithStar", help="when using testAlias use '*' as type", action="store_true") + +argv = sys.argv[:] +if '--' in argv: + argv.remove("--") +args, unknown = parser.parse_known_args(argv) + +process = cms.Process("Test") + +process.source = cms.Source("EmptySource") + +process.maxEvents.input = 1 + +process.a = cms.EDProducer("IntProducer", ivalue = cms.int32(1)) +process.b = cms.EDProducer("AddIntsProducer", labels = cms.VInputTag(cms.InputTag("a"))) + +process.f1 = cms.EDFilter("IntProductFilter", label = cms.InputTag("b")) + +process.c = cms.EDProducer("IntProducer", ivalue = cms.int32(2)) +process.d = cms.EDProducer("AddIntsProducer", labels = cms.VInputTag(cms.InputTag("c"))) +process.e = cms.EDProducer("AddIntsProducer", labels = cms.VInputTag(cms.InputTag("d"))) + +process.prodOnPath = cms.EDProducer("AddIntsProducer", labels = cms.VInputTag(cms.InputTag("d"), cms.InputTag("e"))) + +if args.filterSucceeds: + threshold = 1 +else: + threshold = 3 + +process.f2 = cms.EDFilter("IntProductFilter", label = cms.InputTag("e"), threshold = cms.int32(threshold)) + +if args.reverseDependencies: + process.d.labels[0]=cms.InputTag("e") + process.e.labels[0]=cms.InputTag("c") + process.f2.label = cms.InputTag("d") + +if args.testView: + process.f3 = cms.EDAnalyzer("SimpleViewAnalyzer", + label = cms.untracked.InputTag("f"), + sizeMustMatch = cms.untracked.uint32(10), + checkSize = cms.untracked.bool(False) + ) + process.f = cms.EDProducer("OVSimpleProducer", size = cms.int32(10)) + producttype = "edmtestSimplesOwned" +else: + process.f= cms.EDProducer("IntProducer", ivalue = cms.int32(3)) + process.f3 = cms.EDFilter("IntProductFilter", label = cms.InputTag("f")) + producttype = "edmtestIntProduct" + +if args.testAlias: + if args.aliasWithStar: + producttype = "*" + + process.f3.label = "aliasToF" + process.aliasToF = cms.EDAlias( + f = cms.VPSet( + cms.PSet( + type = cms.string(producttype), + ) + ) + ) + + +process.p = cms.Path(process.f1+process.prodOnPath+process.f2+process.f3, cms.ConditionalTask(process.a, process.b, process.c, process.d, process.e, process.f)) + +process.tst = cms.EDAnalyzer("IntTestAnalyzer", moduleLabel = cms.untracked.InputTag("f"), valueMustMatch = cms.untracked.int32(3), + valueMustBeMissing = cms.untracked.bool(not args.filterSucceeds)) + +process.endp = cms.EndPath(process.tst) + +#process.add_(cms.Service("Tracer"))