Skip to content

Commit

Permalink
Merge pull request #1293 from cuthbertLab/iterator-mypy
Browse files Browse the repository at this point in the history
WIP: Iterator mypy
  • Loading branch information
mscuthbert authored May 5, 2022
2 parents 56dc576 + c9505d4 commit 11dc9e6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 57 deletions.
10 changes: 7 additions & 3 deletions music21/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from math import isclose
from typing import (Dict, Iterable, List, Optional, Set, Tuple, cast,
TypeVar, Type, Union, Generic, Literal, overload,
Sequence, TYPE_CHECKING)
Sequence, TYPE_CHECKING, Any)

from music21 import base

Expand Down Expand Up @@ -2882,6 +2882,7 @@ def replaceDerived(startSite=self):
target = self._endElements[i - eLen]
self._endElements[i - eLen] = replacement

# noinspection PyTypeChecker
self.coreSetElementOffset(replacement, OffsetSpecial.AT_END, addElement=True)
replacement.sites.add(self)

Expand Down Expand Up @@ -6501,7 +6502,7 @@ def haveAccidentalsBeenMade(self):
'''
return self.streamStatus.accidentals

def makeNotation(self,
def makeNotation(self: StreamType,
*,
meterStream=None,
refStreamOrTimeRange=None,
Expand Down Expand Up @@ -6550,6 +6551,7 @@ def makeNotation(self,
'final'
'''
# determine what is the object to work on first
returnStream: Union[StreamType, Stream[Any]]
if inPlace:
returnStream = self
else:
Expand Down Expand Up @@ -6578,7 +6580,7 @@ def makeNotation(self,
inPlace=True,
bestClef=bestClef)

measureStream = returnStream.getElementsByClass(Measure).stream()
measureStream: Stream[Measure] = returnStream.getElementsByClass(Measure).stream()
# environLocal.printDebug(['Stream.makeNotation(): post makeMeasures,
# length', len(returnStream)])
if not measureStream:
Expand Down Expand Up @@ -13782,6 +13784,8 @@ def makeNotation(self,
# no matter, let's just be extra cautious and run this here (Feb 2021 - JTW)
returnStream.coreElementsChanged()
else: # call the base method
if TYPE_CHECKING:
assert isinstance(returnStream, Score)
super(Score, returnStream).makeNotation(meterStream=meterStream,
refStreamOrTimeRange=refStreamOrTimeRange,
inPlace=True,
Expand Down
1 change: 1 addition & 0 deletions music21/stream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def coreStoreAtEnd(self, element, setActiveSite=True):
Core method for adding end elements.
To be called by other methods.
'''
# noinspection PyTypeChecker
self.coreSetElementOffset(element, OffsetSpecial.AT_END, addElement=True)
element.sites.add(self)
# need to explicitly set the activeSite of the element
Expand Down
43 changes: 27 additions & 16 deletions music21/stream/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
# import inspect
import unittest
from math import inf
from typing import Optional, TypeVar

from music21 import common
from music21.common.numberTools import opFrac
from music21.exceptions21 import Music21Exception
from music21 import prebase


StreamIteratorType = TypeVar('StreamIteratorType', bound='music21.stream.iterator.StreamIterator')

class FilterException(Music21Exception):
pass
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -70,7 +74,7 @@ class StreamFilter(prebase.ProtoM21Object):
# def reset(self):
# pass

def __call__(self, item, iterator):
def __call__(self, item, iterator: Optional[StreamIteratorType] = None):
return True

class IsFilter(StreamFilter):
Expand Down Expand Up @@ -128,7 +132,7 @@ def __init__(self, target=()):
def reset(self):
self.numToFind = len(self.target)

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
if self.numToFind == 0: # short circuit -- we already have
raise StopIteration

Expand Down Expand Up @@ -182,7 +186,7 @@ def __init__(self, target=()):
def reset(self):
pass # do nothing: inf - 1 = inf

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return not super().__call__(item, iterator)


Expand All @@ -204,7 +208,7 @@ def __init__(self, searchId=None):
searchIdLower = searchId
self.searchId = searchIdLower

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
if item.id == self.searchId:
return True
else:
Expand Down Expand Up @@ -259,7 +263,7 @@ def __eq__(self, other):
return False
return True

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return not item.classSet.isdisjoint(self.classList)

def _reprInternal(self):
Expand Down Expand Up @@ -289,7 +293,7 @@ class ClassNotFilter(ClassFilter):
'''
derivationStr = 'getElementsNotOfClass'

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return item.classSet.isdisjoint(self.classList)


Expand Down Expand Up @@ -328,7 +332,7 @@ def __init__(self, groupFilterList=()):
groupFilterList = [groupFilterList]
self.groupFilterList = groupFilterList

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
eGroups = item.groups
for groupName in self.groupFilterList:
if groupName in eGroups:
Expand Down Expand Up @@ -383,15 +387,19 @@ def _reprInternal(self) -> str:
return str(self.offsetStart) + '-' + str(self.offsetEnd)


def __call__(self, e, iterator):
s = iterator.srcStream
if s is e:
return False
offset = s.elementOffset(e)
if s.isSorted:
stopAfterEnd = self.stopAfterEnd
def __call__(self, e, iterator=None):
if iterator is None:
offset = e.offset
stopAfterEnd = False
else:
stopAfterEnd = False # never stop after end on unsorted stream
s = iterator.srcStream
if s is e:
return False
offset = s.elementOffset(e)
if s.isSorted:
stopAfterEnd = self.stopAfterEnd
else:
stopAfterEnd = False # never stop after end on unsorted stream
return self.isElementOffsetInRange(e, offset, stopAfterEnd=stopAfterEnd)

def isElementOffsetInRange(self, e, offset, *, stopAfterEnd=False) -> bool:
Expand Down Expand Up @@ -472,7 +480,10 @@ class OffsetHierarchyFilter(OffsetFilter):
'''
derivationStr = 'getElementsByOffsetInHierarchy'

def __call__(self, e, iterator):
def __call__(self, e, iterator=None):
if iterator is None:
raise TypeError('Cannot call OffsetHierarchyFilter without an iterator')

s = iterator.srcStream
if s is e:
return False
Expand Down
61 changes: 23 additions & 38 deletions music21/stream/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import copy
from typing import (TypeVar, List, Union, Callable, Optional, Literal,
TypedDict, Generic, overload, Iterable, Type, cast)
TypedDict, Generic, overload, Iterable, Type, cast,
Any, TYPE_CHECKING)
import unittest
import warnings

Expand Down Expand Up @@ -126,7 +127,7 @@ class StreamIterator(prebase.ProtoM21Object, Generic[M21ObjType]):
def __init__(self,
srcStream: StreamType,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList: Union[List[FilterType], FilterType, None] = None,
restoreActiveSites: bool = True,
activeInformation: Optional[ActiveInformation] = None,
Expand Down Expand Up @@ -697,14 +698,15 @@ def matchesFilters(self, e) -> bool:
'''
returns False if any filter returns False, True otherwise.
'''
f: Union[Callable[[Any, Optional[Any]], Any], filters.StreamFilter]
for f in self.filters:
f: Union[Callable, filters.StreamFilter]
try:
try:
if f(e, self) is False:
return False
except TypeError: # one element filters are acceptable.
f: Callable
if TYPE_CHECKING:
assert isinstance(f, filters.StreamFilter)
if f(e) is False:
return False
except StopIteration: # pylint: disable=try-except-raise
Expand All @@ -717,10 +719,6 @@ def _newBaseStream(self) -> 'music21.stream.Stream':
>>> s = stream.Stream()
So why does this exist? Since we can't import "music21.stream" here,
we will look in `srcStream.__class__.mro()` for the Stream
object to import.
This is used in places where returnStreamSubclass is False, so we
cannot just call `type(StreamIterator.srcStream)()`
Expand All @@ -729,26 +727,12 @@ def _newBaseStream(self) -> 'music21.stream.Stream':
>>> s = pi._newBaseStream()
>>> s
<music21.stream.Stream 0x1047eb2e8>
>>> pi.srcStream = note.Note()
>>> pi._newBaseStream()
Traceback (most recent call last):
music21.stream.iterator.StreamIteratorException: ...
'''
StreamBase = None
for x in self.srcStream.__class__.mro():
if x.__name__ == 'Stream':
StreamBase = x
break

try:
return StreamBase()
except TypeError: # 'NoneType' object is not callable.
raise StreamIteratorException(
f"You've given a 'stream' that is not a stream! {self.srcStream}")
from music21 import stream
return stream.Stream()

@overload
def stream(self, returnStreamSubClass: Literal[False] = True) -> 'music21.stream.Stream':
def stream(self, returnStreamSubClass: Literal[False]) -> 'music21.stream.Stream':
# ignore this code -- just here until Astroid bug #1015 is fixed
x: 'music21.stream.Stream' = self.streamObj
return x
Expand Down Expand Up @@ -843,7 +827,7 @@ def stream(self, returnStreamSubClass=True) -> Union['music21.stream.Stream', St
else:
derivationMethods = []
for f in self.filters:
if hasattr(f, 'derivationStr'):
if isinstance(f, filters.StreamFilter):
dStr = f.derivationStr
else:
dStr = f.__name__ # function; lambda returns <lambda>
Expand Down Expand Up @@ -968,7 +952,7 @@ def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> StreamIterator[ChangedM21ObjType]:
x: StreamIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
x = cast(StreamIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

@overload
Expand Down Expand Up @@ -1517,24 +1501,24 @@ class OffsetIterator(StreamIterator[M21ObjType]):
def __init__(self,
srcStream,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList=None,
restoreActiveSites=True,
activeInformation=None,
ignoreSorting=False
):
super().__init__(srcStream,
restrictClass=restrictClass,
# restrictClass=restrictClass,
filterList=filterList,
restoreActiveSites=restoreActiveSites,
activeInformation=activeInformation,
ignoreSorting=ignoreSorting,
)
self.raiseStopIterationNext = False
self.nextToYield = []
self.nextToYield: List[M21ObjType] = []
self.nextOffsetToYield = None

def __next__(self) -> List[M21ObjType]:
def __next__(self) -> List[M21ObjType]: # type: ignore
if self.raiseStopIterationNext:
raise StopIteration

Expand Down Expand Up @@ -1601,7 +1585,7 @@ def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> OffsetIterator[ChangedM21ObjType]:
x: OffsetIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
x = cast(OffsetIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

@overload
Expand Down Expand Up @@ -1699,7 +1683,7 @@ class RecursiveIterator(StreamIterator[M21ObjType]):
def __init__(self,
srcStream,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList=None,
restoreActiveSites=True,
activeInformation=None,
Expand All @@ -1708,7 +1692,7 @@ def __init__(self,
ignoreSorting=False
): # , parentIterator=None):
super().__init__(srcStream,
restrictClass=restrictClass,
# restrictClass=restrictClass,
filterList=filterList,
restoreActiveSites=restoreActiveSites,
activeInformation=activeInformation,
Expand All @@ -1726,7 +1710,7 @@ def __init__(self,

if streamsOnly is True:
self.filters.append(filters.ClassFilter('Stream'))
self.childRecursiveIterator = None
self.childRecursiveIterator: Optional[RecursiveIterator[Any]] = None
# not yet used.
# self.parentIterator = None

Expand Down Expand Up @@ -1754,7 +1738,7 @@ def __next__(self) -> M21ObjType:
self.activeInformation['index'] = -1
self.activeInformation['lastYielded'] = self.srcStream
self.returnSelf = False
return self.srcStream
return cast(M21ObjType, self.srcStream)

elif self.returnSelf is True:
self.returnSelf = False
Expand All @@ -1777,7 +1761,7 @@ def __next__(self) -> M21ObjType:
# in a recursive filter, the stream does not need to match the filter,
# only the internal elements.
if e.isStream:
self.childRecursiveIterator = RecursiveIterator(
childRecursiveIterator: RecursiveIterator[M21ObjType] = RecursiveIterator(
srcStream=e,
restoreActiveSites=self.restoreActiveSites,
filterList=self.filters, # shared list...
Expand All @@ -1789,7 +1773,8 @@ def __next__(self) -> M21ObjType:
newStartOffset = (self.iteratorStartOffsetInHierarchy
+ self.srcStream.elementOffset(e))

self.childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset
childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset
self.childRecursiveIterator = childRecursiveIterator
if self.matchesFilters(e) is False:
continue

Expand Down

0 comments on commit 11dc9e6

Please sign in to comment.