diff --git a/FWCore/ParameterSet/python/Config.py b/FWCore/ParameterSet/python/Config.py index c6d198ba04ed6..e26c65c7c0be6 100644 --- a/FWCore/ParameterSet/python/Config.py +++ b/FWCore/ParameterSet/python/Config.py @@ -1093,6 +1093,11 @@ def globalReplace(self,label,new): if not hasattr(self,label): raise LookupError("process has no item of label "+label) setattr(self,label,new) + def setSwitchProducerCaseForAll(self, switchProducerType, case): + """Set the chosen case to 'case' for all SwitchProducers ot type 'SwitchProducerType'""" + for sp in self.__switchproducers.values(): + if sp.__class__.__name__ == switchProducerType: + sp.setCase_(case) def _insertInto(self, parameterSet, itemDict): for name,value in itemDict.items(): value.insertInto(parameterSet, name) @@ -1850,6 +1855,16 @@ def __init__(self, **kargs): ), **kargs) specialImportRegistry.registerSpecialImportForType(SwitchProducerTest, "from test import SwitchProducerTest") + class SwitchProducerTest2(SwitchProducer): + def __init__(self, **kargs): + super(SwitchProducerTest2,self).__init__( + dict( + test1 = lambda: (True, -7), + test2 = lambda: (True, -5), + test30 = lambda: (True, -10), + ), **kargs) + specialImportRegistry.registerSpecialImportForType(SwitchProducerTest2, "from test import SwitchProducerTest2") + class TestModuleCommand(unittest.TestCase): def setUp(self): """Nothing to do """ @@ -3069,6 +3084,50 @@ def testSwitchProducer(self): self.assertEqual((True,"EDAlias"), p.values["sp@test2"][1].values["@module_edm_type"]) self.assertEqual((True,"Bar"), p.values["sp@test2"][1].values["a"][1][0].values["type"]) + # Forcing the choice + proc = Process("test") + proc.sp1 = SwitchProducerTest(test1 = EDProducer("Foo1"), test3 = EDProducer("Fred1")) + proc.sp2 = SwitchProducerTest(test1 = EDProducer("Foo2"), test2 = EDProducer("Bar2"), test3 = EDProducer("Fred2")) + proc.sp10 = SwitchProducerTest2(test1 = EDProducer("Foo10"), test2 = EDProducer("Bar10"), test30 = EDProducer("Wilma10")) + proc.t = Task(proc.sp1, proc.sp2, proc.sp10) + proc.p = Path(proc.t) + self.assertEqual(proc.sp1._getProducer().type_(), "Fred1") + self.assertEqual(proc.sp2._getProducer().type_(), "Fred2") + self.assertEqual(proc.sp10._getProducer().type_(), "Bar10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test1") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Foo2") + self.assertEqual(proc.sp10._getProducer().type_(), "Bar10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest2", "test30") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Foo2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2") + self.assertRaises(RuntimeError, proc.sp1._getProducer) + self.assertEqual(proc.sp2._getProducer().type_(), "Bar2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + proc.setSwitchProducerCaseForAll("SwitchProducerTest", None) + self.assertEqual(proc.sp1._getProducer().type_(), "Fred1") + self.assertEqual(proc.sp2._getProducer().type_(), "Fred2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + p = TestMakePSet() + proc.fillProcessDesc(p) + self.assertEqual((False, "sp1@test3"), p.values["sp1"][1].values["@chosen_case"]) + self.assertEqual((False, "sp2@test3"), p.values["sp2"][1].values["@chosen_case"]) + self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"]) + proc.setSwitchProducerCaseForAll("SwitchProducerTest", "test2") + p = TestMakePSet() + self.assertRaises(RuntimeError, proc.fillProcessDesc, p) + proc.sp1.setCase_("test1") + self.assertEqual(proc.sp1._getProducer().type_(), "Foo1") + self.assertEqual(proc.sp2._getProducer().type_(), "Bar2") + self.assertEqual(proc.sp10._getProducer().type_(), "Wilma10") + p = TestMakePSet() + proc.fillProcessDesc(p) + self.assertEqual((False, "sp1@test1"), p.values["sp1"][1].values["@chosen_case"]) + self.assertEqual((False, "sp2@test2"), p.values["sp2"][1].values["@chosen_case"]) + self.assertEqual((False, "sp10@test30"), p.values["sp10"][1].values["@chosen_case"]) + def testPrune(self): p = Process("test") p.a = EDAnalyzer("MyAnalyzer")