From bc718678e1f599bcbbfc3b5cf91ccf7cd42376cb Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Sun, 18 Oct 2020 10:42:05 -0500 Subject: [PATCH] Added PSetTemplate to allow description of a PSet This is used in conjunction with required or optional functionality to describe the allowed parameters of a PSet. --- FWCore/ParameterSet/python/Config.py | 6 ++- FWCore/ParameterSet/python/Mixins.py | 4 ++ FWCore/ParameterSet/python/Types.py | 73 +++++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/FWCore/ParameterSet/python/Config.py b/FWCore/ParameterSet/python/Config.py index 218dc66ec151d..60e12a2f127e3 100644 --- a/FWCore/ParameterSet/python/Config.py +++ b/FWCore/ParameterSet/python/Config.py @@ -1990,7 +1990,9 @@ def __init__(self,*arg,**args): def testProcessDumpPython(self): self.assertEqual(Process("test").dumpPython(), -"""import FWCore.ParameterSet.Config as cms\n\nprocess = cms.Process("test") +"""import FWCore.ParameterSet.Config as cms + +process = cms.Process("test") process.maxEvents = cms.untracked.PSet( input = cms.optional.untracked.int32, @@ -2011,7 +2013,7 @@ def testProcessDumpPython(self): emptyRunLumiMode = cms.obsolete.untracked.string, eventSetup = cms.untracked.PSet( forceNumberOfConcurrentIOVs = cms.untracked.PSet( - + allowAnyLabel_=cms.required.untracked.uint32 ), numberOfConcurrentIOVs = cms.untracked.uint32(1) ), diff --git a/FWCore/ParameterSet/python/Mixins.py b/FWCore/ParameterSet/python/Mixins.py index 390adf5812e07..aa53cf042104d 100644 --- a/FWCore/ParameterSet/python/Mixins.py +++ b/FWCore/ParameterSet/python/Mixins.py @@ -364,6 +364,10 @@ def dumpPython(self, options=PrintOptions()): # usings need to go first resultList = usings resultList.extend(others) + if self.__validator is not None: + options.indent() + resultList.append(options.indentation()+"allowAnyLabel_="+self.__validator.dumpPython(options)) + options.unindent() return ',\n'.join(resultList)+'\n' def __repr__(self): return self.dumpPython() diff --git a/FWCore/ParameterSet/python/Types.py b/FWCore/ParameterSet/python/Types.py index 6b9c9b3156d52..f2f857e4a85be 100644 --- a/FWCore/ParameterSet/python/Types.py +++ b/FWCore/ParameterSet/python/Types.py @@ -59,6 +59,8 @@ def __setattr__(self,name, value): if v is not None: return setattr(v,name,value) else: + if not name.startswith('_'): + raise AttributeError("%r object has no attribute %r" % (self.__class__.__name__, name)) return object.__setattr__(self, name, value) def __bool__(self): v = self.__dict__.get('_ProxyParameter__value',None) @@ -71,7 +73,9 @@ def dumpPython(self, options=PrintOptions()): v = "cms."+self._dumpPythonName() if not _ParameterTypeBase.isTracked(self): v+=".untracked" - return v+'.'+self.__type.__name__ + if hasattr(self.__type, "__name__"): + return v+'.'+self.__type.__name__ + return v+'.'+self.__type.dumpPython(options) def validate_(self,value): return isinstance(value,self.__type) def convert_(self,value): @@ -138,6 +142,19 @@ def __call__(self,value): raise RuntimeError("Cannot convert "+str(value)+" to 'allowed' type") return chosenType(value) +class _PSetTemplate(object): + def __init__(self, *args, **kargs): + self._pset = PSet(*args,**kargs) + self.__dict__['_PSetTemplate__value'] = None + def __call__(self, value): + self.__dict__ + return self._pset.clone(**value) + def dumpPython(self, options=PrintOptions()): + v =self.__dict__.get('_ProxyParameter__value',None) + if v is not None: + return v.dumpPython(options) + return "PSetTemplate(\n"+_Parameterizable.dumpPython(self._pset, options)+options.indentation()+")" + class _ProxyParameterFactory(object): """Class type for ProxyParameter types to allow nice syntax""" @@ -160,7 +177,17 @@ def __call__(self, *args): return self.type(_AllowedParameterTypes(*args)) return _AllowedWrapper(self.__isUntracked, self.__type) - + if name == 'PSetTemplate': + class _PSetTemplateWrapper(object): + def __init__(self, untracked, type): + self.untracked = untracked + self.type = type + def __call__(self,*args,**kargs): + if self.untracked: + return untracked(self.type(_PSetTemplate(*args,**kargs))) + return self.type(_PSetTemplate(*args,**kargs)) + return _PSetTemplateWrapper(self.__isUntracked, self.__type) + type = globals()[name] if not issubclass(type, _ParameterTypeBase): raise AttributeError @@ -1859,6 +1886,27 @@ def testRequired(self): self.assertEqual(p1.foo.value(),3) self.failIf(p1.foo.isTracked()) self.assertRaises(ValueError,setattr,p1, 'bar', 'bad') + #PSetTemplate use + p1 = PSet(aPSet = required.PSetTemplate()) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n\n )\n)') + p1.aPSet = dict() + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)') + p1 = PSet(aPSet=required.PSetTemplate(a=required.int32)) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)') + p1.aPSet = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)') + self.assertEqual(p1.aPSet.a.value(), 5) + p1 = PSet(aPSet=required.untracked.PSetTemplate(a=required.int32)) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.required.untracked.PSetTemplate(\n a = cms.required.int32\n )\n)') + p1.aPSet = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)') + self.assertEqual(p1.aPSet.a.value(), 5) + p1 = PSet(allowAnyLabel_=required.PSetTemplate(a=required.int32)) + self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)') + p1.foo = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.required.PSetTemplate(\n a = cms.required.int32\n )\n)') + self.assertEqual(p1.foo.a.value(), 5) + def testOptional(self): p1 = PSet(anInt = optional.int32) self.assert_(hasattr(p1,"anInt")) @@ -1887,6 +1935,27 @@ def testOptional(self): self.failIf(p1.f) p1.f.append(3) self.assert_(p1.f) + #PSetTemplate use + p1 = PSet(aPSet = optional.PSetTemplate()) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n\n )\n)') + p1.aPSet = dict() + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n\n )\n)') + p1 = PSet(aPSet=optional.PSetTemplate(a=optional.int32)) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)') + p1.aPSet = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.PSet(\n a = cms.int32(5)\n )\n)') + self.assertEqual(p1.aPSet.a.value(), 5) + p1 = PSet(aPSet=optional.untracked.PSetTemplate(a=optional.int32)) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.optional.untracked.PSetTemplate(\n a = cms.optional.int32\n )\n)') + p1.aPSet = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n aPSet = cms.untracked.PSet(\n a = cms.int32(5)\n )\n)') + self.assertEqual(p1.aPSet.a.value(), 5) + p1 = PSet(allowAnyLabel_=optional.PSetTemplate(a=optional.int32)) + self.assertEqual(p1.dumpPython(), 'cms.PSet(\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)') + p1.foo = dict(a=5) + self.assertEqual(p1.dumpPython(),'cms.PSet(\n foo = cms.PSet(\n a = cms.int32(5)\n ),\n allowAnyLabel_=cms.optional.PSetTemplate(\n a = cms.optional.int32\n )\n)') + self.assertEqual(p1.foo.a.value(), 5) + def testAllowed(self): p1 = PSet(aValue = required.allowed(int32, string))