Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added python type annotation to some configuration classes #44453

Merged
merged 1 commit into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 86 additions & 85 deletions FWCore/ParameterSet/python/Config.py

Large diffs are not rendered by default.

111 changes: 56 additions & 55 deletions FWCore/ParameterSet/python/Mixins.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from __future__ import print_function
from builtins import range, object
import inspect
from typing import Union

class _ConfigureComponent(object):
"""Denotes a class that can be used by the Processes class"""
def _isTaskComponent(self):
def _isTaskComponent(self) -> bool:
return False

class PrintOptions(object):
def __init__(self, indent = 0, deltaIndent = 4, process = True, targetDirectory = None, useSubdirectories = False):
def __init__(self, indent:int = 0, deltaIndent:int = 4, process:bool = True, targetDirectory: Union[str, None] = None, useSubdirectories:bool = False):
self.indent_= indent
self.deltaIndent_ = deltaIndent
self.isCfg = process
self.targetDirectory = targetDirectory
self.useSubdirectories = useSubdirectories
def indentation(self):
def indentation(self) -> str:
return ' '*self.indent_
def indent(self):
self.indent_ += self.deltaIndent_
Expand Down Expand Up @@ -58,32 +59,32 @@ def __init__(self):
self.__dict__["_isFrozen"] = False
self.__isTracked = True
self._isModified = False
def isModified(self):
def isModified(self) -> bool:
return self._isModified
def resetModified(self):
self._isModified=False
def configTypeName(self):
def configTypeName(self) -> str:
if self.isTracked():
return type(self).__name__
return 'untracked '+type(self).__name__
def pythonTypeName(self):
def pythonTypeName(self) -> str:
if self.isTracked():
return 'cms.'+type(self).__name__
return 'cms.untracked.'+type(self).__name__
def dumpPython(self, options=PrintOptions()):
def dumpPython(self, options:PrintOptions=PrintOptions()) -> str:
specialImportRegistry.registerUse(self)
return self.pythonTypeName()+"("+self.pythonValue(options)+")"
def __repr__(self):
def __repr__(self) -> str:
return self.dumpPython()
def isTracked(self):
def isTracked(self) -> bool:
return self.__isTracked
def setIsTracked(self,trackness):
def setIsTracked(self,trackness:bool):
self.__isTracked = trackness
def isFrozen(self):
def isFrozen(self) -> bool:
return self._isFrozen
def setIsFrozen(self):
self._isFrozen = True
def isCompatibleCMSType(self,aType):
def isCompatibleCMSType(self,aType) -> bool:
return isinstance(self,aType)
def _checkAndReturnValueWithType(self, valueWithType):
if isinstance(valueWithType, type(self)):
Expand All @@ -106,31 +107,31 @@ def setValue(self,value):
if value!=self._value:
self._isModified=True
self._value=value
def configValue(self, options=PrintOptions()):
def configValue(self, options:PrintOptions=PrintOptions()) -> str:
return str(self._value)
def pythonValue(self, options=PrintOptions()):
def pythonValue(self, options:PrintOptions=PrintOptions()) -> str:
return self.configValue(options)
def __eq__(self,other):
def __eq__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value == other._value
return self._value == other
def __ne__(self,other):
def __ne__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value != other._value
return self._value != other
def __lt__(self,other):
def __lt__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value < other._value
return self._value < other
def __le__(self,other):
def __le__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value <= other._value
return self._value <= other
def __gt__(self,other):
def __gt__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value > other._value
return self._value > other
def __ge__(self,other):
def __ge__(self,other) -> bool:
if isinstance(other,_SimpleParameterTypeBase):
return self._value >= other._value
return self._value >= other
Expand All @@ -140,25 +141,25 @@ class UsingBlock(_SimpleParameterTypeBase):
"""For injection purposes, pretend this is a new parameter type
then have a post process step which strips these out
"""
def __init__(self,value, s='', loc=0, file=''):
def __init__(self,value, s:str='', loc:int=0, file:str=''):
super(UsingBlock,self).__init__(value)
self.s = s
self.loc = loc
self.file = file
self.isResolved = False
@staticmethod
def _isValid(value):
def _isValid(value) -> bool:
return isinstance(value,str)
def _valueFromString(value):
def _valueFromString(value) -> str:
"""only used for cfg-parsing"""
return string(value)
def insertInto(self, parameterSet, myname):
return str(value)
def insertInto(self, parameterSet, myname:str):
value = self.value()
# doesn't seem to handle \0 correctly
#if value == '\0':
# value = ''
parameterSet.addString(self.isTracked(), myname, value)
def dumpPython(self, options=PrintOptions()):
def dumpPython(self, options:PrintOptions=PrintOptions()) -> str:
if options.isCfg:
return "process."+self.value()
else:
Expand Down Expand Up @@ -188,7 +189,7 @@ def __init__(self,*arg,**kargs):
def parameterNames_(self):
"""Returns the name of the parameters"""
return self.__parameterNames[:]
def isModified(self):
def isModified(self) -> bool:
if self._isModified:
return True
for name in self.parameterNames_():
Expand All @@ -198,7 +199,7 @@ def isModified(self):
return True
return False

def hasParameter(self, params):
def hasParameter(self, params) -> bool:
"""
_hasParameter_

Expand Down Expand Up @@ -239,7 +240,7 @@ def parameters_(self):
result[name]=copy.deepcopy(self.__dict__[name])
return result

def __addParameter(self, name, value):
def __addParameter(self, name:str, value):
if name == 'allowAnyLabel_':
self.__validator = value
self._isModified = True
Expand Down Expand Up @@ -267,7 +268,7 @@ def __setParameters(self,parameters):
self.__addParameter(name, value)
if v is not None:
self.__validator=v
def __setattr__(self,name,value):
def __setattr__(self,name:str,value):
#since labels are not supposed to have underscores at the beginning
# I will assume that if we have such then we are setting an internal variable
if self.isFrozen() and not (name in ["_Labelable__label","_isFrozen"] or name.startswith('_')):
Expand All @@ -290,21 +291,21 @@ def __setattr__(self,name,value):
self.__dict__[name].setValue(value)
self._isModified = True

def isFrozen(self):
def isFrozen(self) -> bool:
return self._isFrozen
def setIsFrozen(self):
self._isFrozen = True
for name in self.parameterNames_():
self.__dict__[name].setIsFrozen()
def __delattr__(self,name):
def __delattr__(self,name:str):
if self.isFrozen():
raise ValueError("Object already added to a process. It is read only now")
super(_Parameterizable,self).__delattr__(name)
self.__parameterNames.remove(name)
@staticmethod
def __raiseBadSetAttr(name):
def __raiseBadSetAttr(name:str):
raise TypeError(name+" does not already exist, so it can only be set to a CMS python configuration type")
def dumpPython(self, options=PrintOptions()):
def dumpPython(self, options:PrintOptions=PrintOptions()) -> str:
specialImportRegistry.registerUse(self)
sortedNames = sorted(self.parameterNames_())
if len(sortedNames) > 200:
Expand Down Expand Up @@ -376,7 +377,7 @@ def dumpPython(self, options=PrintOptions()):
resultList.append(options.indentation()+"allowAnyLabel_="+self.__validator.dumpPython(options))
options.unindent()
return ',\n'.join(resultList)+'\n'
def __repr__(self):
def __repr__(self) -> str:
return self.dumpPython()
def insertContentsInto(self, parameterSet):
for name in self.parameterNames_():
Expand All @@ -395,7 +396,7 @@ def __init__(self,type_,*arg,**kargs):
# del args['type_']
super(_TypedParameterizable,self).__init__(*arg,**kargs)
saveOrigin(self, 1)
def _place(self,name,proc):
def _place(self,name:str,proc):
self._placeImpl(name,proc)
def type_(self):
"""returns the type of the object, e.g. 'FooProducer'"""
Expand Down Expand Up @@ -444,7 +445,7 @@ def clone(self, *args, **params):
return returnValue

@staticmethod
def __findDefaultsFor(label,type):
def __findDefaultsFor(label:str,type):
#This routine is no longer used, but I might revive it in the future
import sys
import glob
Expand Down Expand Up @@ -477,7 +478,7 @@ def __findDefaultsFor(label,type):
def directDependencies(self):
return []

def dumpConfig(self, options=PrintOptions()):
def dumpConfig(self, options:PrintOptions=PrintOptions()) -> str:
config = self.__type +' { \n'
for name in self.parameterNames_():
param = self.__dict__[name]
Expand All @@ -487,7 +488,7 @@ def dumpConfig(self, options=PrintOptions()):
config += options.indentation()+'}\n'
return config

def dumpPython(self, options=PrintOptions()):
def dumpPython(self, options:PrintOptions=PrintOptions()) -> str:
specialImportRegistry.registerUse(self)
result = "cms."+str(type(self).__name__)+'("'+self.type_()+'"'
nparam = len(self.parameterNames_())
Expand All @@ -497,21 +498,21 @@ def dumpPython(self, options=PrintOptions()):
result += ",\n"+_Parameterizable.dumpPython(self,options)+options.indentation() + ")\n"
return result

def dumpPythonAttributes(self, myname, options):
def dumpPythonAttributes(self, myname:str, options:PrintOptions) -> str:
""" dumps the object with all attributes declared after the constructor"""
result = ""
for name in sorted(self.parameterNames_()):
param = self.__dict__[name]
result += options.indentation() + myname + "." + name + " = " + param.dumpPython(options) + "\n"
return result

def nameInProcessDesc_(self, myname):
def nameInProcessDesc_(self, myname:str):
return myname;
def moduleLabel_(self, myname):
def moduleLabel_(self, myname:str):
return myname
def appendToProcessDescList_(self, lst, myname):
def appendToProcessDescList_(self, lst, myname:str):
lst.append(self.nameInProcessDesc_(myname))
def insertInto(self, parameterSet, myname):
def insertInto(self, parameterSet, myname:str):
newpset = parameterSet.newPSet()
newpset.addString(True, "@module_label", self.moduleLabel_(myname))
newpset.addString(True, "@module_type", self.type_())
Expand All @@ -523,13 +524,13 @@ def insertInto(self, parameterSet, myname):

class _Labelable(object):
"""A 'mixin' used to denote that the class can be paired with a label (e.g. an EDProducer)"""
def label_(self):
def label_(self) -> str:
if not hasattr(self, "_Labelable__label"):
raise RuntimeError("module has no label. Perhaps it wasn't inserted into the process?")
return self.__label
def hasLabel_(self):
def hasLabel_(self) -> bool:
return hasattr(self, "_Labelable__label") and self.__label is not None
def setLabel(self,label):
def setLabel(self,label:str):
if self.hasLabel_() :
if self.label_() != label and label is not None :
msg100 = "Attempting to change the label of a Labelable object, possibly an attribute of the Process\n"
Expand All @@ -547,7 +548,7 @@ def setLabel(self,label):
msg112 = " 4. Compose Sequences: newName = cms.Sequence(oldName)\n"
raise ValueError(msg100+msg101+msg102+msg103+msg104+msg105+msg106+msg107+msg108+msg109+msg110+msg111+msg112)
self.__label = label
def label(self):
def label(self) -> str:
#print "WARNING: _Labelable::label() needs to be changed to label_()"
return self.__label
def __str__(self):
Expand All @@ -557,7 +558,7 @@ def __str__(self):
return str(self.__label)
def dumpSequenceConfig(self):
return str(self.__label)
def dumpSequencePython(self, options=PrintOptions()):
def dumpSequencePython(self, options:PrintOptions=PrintOptions()):
if options.isCfg:
return 'process.'+str(self.__label)
else:
Expand Down Expand Up @@ -599,7 +600,7 @@ def __setitem__(self,key,value):
raise TypeError("can not insert the type "+str(type(value))+" in container "+self._labelIfAny())
super(_ValidatingListBase,self).__setitem__(key,value)
@classmethod
def _isValid(cls,seq):
def _isValid(cls,seq) -> bool:
# see if strings get reinterpreted as lists
if isinstance(seq, str):
return False
Expand Down Expand Up @@ -633,7 +634,7 @@ def insert(self,i,x):
if not self._itemIsValid(x):
raise TypeError("wrong type being inserted to container "+self._labelIfAny())
super(_ValidatingListBase,self).insert(i,self._itemFromArgument(x))
def _labelIfAny(self):
def _labelIfAny(self) -> str:
result = type(self).__name__
if hasattr(self, '__label'):
result += ' ' + self.__label
Expand All @@ -654,7 +655,7 @@ def setValue(self,v):
self[:] = []
self.extend(v)
self._isModified=True
def configValue(self, options=PrintOptions()):
def configValue(self, options:PrintOptions=PrintOptions()) -> str:
config = '{\n'
first = True
for value in iter(self):
Expand All @@ -667,13 +668,13 @@ def configValue(self, options=PrintOptions()):
options.unindent()
config += options.indentation()+'}\n'
return config
def configValueForItem(self,item, options):
def configValueForItem(self,item, options:PrintOptions) -> str:
return str(item)
def pythonValueForItem(self,item, options):
def pythonValueForItem(self,item, options:PrintOptions) -> str:
return self.configValueForItem(item, options)
def __repr__(self):
return self.dumpPython()
def dumpPython(self, options=PrintOptions()):
def dumpPython(self, options:PrintOptions=PrintOptions()) -> str:
specialImportRegistry.registerUse(self)
result = self.pythonTypeName()+"("
n = len(self)
Expand Down
Loading