Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Iterator mypy #1293

Merged
merged 2 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions music21/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from math import isclose
from typing import (Dict, Iterable, List, Optional, Set, Tuple, cast,
TypeVar, Type, Union, Generic, Literal, overload,
Sequence, TYPE_CHECKING)
Sequence, TYPE_CHECKING, Any)

from music21 import base

Expand Down Expand Up @@ -2882,6 +2882,7 @@ def replaceDerived(startSite=self):
target = self._endElements[i - eLen]
self._endElements[i - eLen] = replacement

# noinspection PyTypeChecker
self.coreSetElementOffset(replacement, OffsetSpecial.AT_END, addElement=True)
replacement.sites.add(self)

Expand Down Expand Up @@ -6501,7 +6502,7 @@ def haveAccidentalsBeenMade(self):
'''
return self.streamStatus.accidentals

def makeNotation(self,
def makeNotation(self: StreamType,
*,
meterStream=None,
refStreamOrTimeRange=None,
Expand Down Expand Up @@ -6550,6 +6551,7 @@ def makeNotation(self,
'final'
'''
# determine what is the object to work on first
returnStream: Union[StreamType, Stream[Any]]
if inPlace:
returnStream = self
else:
Expand Down Expand Up @@ -6578,7 +6580,7 @@ def makeNotation(self,
inPlace=True,
bestClef=bestClef)

measureStream = returnStream.getElementsByClass(Measure).stream()
measureStream: Stream[Measure] = returnStream.getElementsByClass(Measure).stream()
# environLocal.printDebug(['Stream.makeNotation(): post makeMeasures,
# length', len(returnStream)])
if not measureStream:
Expand Down Expand Up @@ -13782,6 +13784,8 @@ def makeNotation(self,
# no matter, let's just be extra cautious and run this here (Feb 2021 - JTW)
returnStream.coreElementsChanged()
else: # call the base method
if TYPE_CHECKING:
assert isinstance(returnStream, Score)
super(Score, returnStream).makeNotation(meterStream=meterStream,
refStreamOrTimeRange=refStreamOrTimeRange,
inPlace=True,
Expand Down
1 change: 1 addition & 0 deletions music21/stream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def coreStoreAtEnd(self, element, setActiveSite=True):
Core method for adding end elements.
To be called by other methods.
'''
# noinspection PyTypeChecker
self.coreSetElementOffset(element, OffsetSpecial.AT_END, addElement=True)
element.sites.add(self)
# need to explicitly set the activeSite of the element
Expand Down
43 changes: 27 additions & 16 deletions music21/stream/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
# import inspect
import unittest
from math import inf
from typing import Optional, TypeVar

from music21 import common
from music21.common.numberTools import opFrac
from music21.exceptions21 import Music21Exception
from music21 import prebase


StreamIteratorType = TypeVar('StreamIteratorType', bound='music21.stream.iterator.StreamIterator')

class FilterException(Music21Exception):
pass
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -70,7 +74,7 @@ class StreamFilter(prebase.ProtoM21Object):
# def reset(self):
# pass

def __call__(self, item, iterator):
def __call__(self, item, iterator: Optional[StreamIteratorType] = None):
return True

class IsFilter(StreamFilter):
Expand Down Expand Up @@ -128,7 +132,7 @@ def __init__(self, target=()):
def reset(self):
self.numToFind = len(self.target)

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
if self.numToFind == 0: # short circuit -- we already have
raise StopIteration

Expand Down Expand Up @@ -182,7 +186,7 @@ def __init__(self, target=()):
def reset(self):
pass # do nothing: inf - 1 = inf

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return not super().__call__(item, iterator)


Expand All @@ -204,7 +208,7 @@ def __init__(self, searchId=None):
searchIdLower = searchId
self.searchId = searchIdLower

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
if item.id == self.searchId:
return True
else:
Expand Down Expand Up @@ -259,7 +263,7 @@ def __eq__(self, other):
return False
return True

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return not item.classSet.isdisjoint(self.classList)

def _reprInternal(self):
Expand Down Expand Up @@ -289,7 +293,7 @@ class ClassNotFilter(ClassFilter):
'''
derivationStr = 'getElementsNotOfClass'

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
return item.classSet.isdisjoint(self.classList)


Expand Down Expand Up @@ -328,7 +332,7 @@ def __init__(self, groupFilterList=()):
groupFilterList = [groupFilterList]
self.groupFilterList = groupFilterList

def __call__(self, item, iterator):
def __call__(self, item, iterator=None):
eGroups = item.groups
for groupName in self.groupFilterList:
if groupName in eGroups:
Expand Down Expand Up @@ -383,15 +387,19 @@ def _reprInternal(self) -> str:
return str(self.offsetStart) + '-' + str(self.offsetEnd)


def __call__(self, e, iterator):
s = iterator.srcStream
if s is e:
return False
offset = s.elementOffset(e)
if s.isSorted:
stopAfterEnd = self.stopAfterEnd
def __call__(self, e, iterator=None):
if iterator is None:
offset = e.offset
stopAfterEnd = False
else:
stopAfterEnd = False # never stop after end on unsorted stream
s = iterator.srcStream
if s is e:
return False
offset = s.elementOffset(e)
if s.isSorted:
stopAfterEnd = self.stopAfterEnd
else:
stopAfterEnd = False # never stop after end on unsorted stream
return self.isElementOffsetInRange(e, offset, stopAfterEnd=stopAfterEnd)

def isElementOffsetInRange(self, e, offset, *, stopAfterEnd=False) -> bool:
Expand Down Expand Up @@ -472,7 +480,10 @@ class OffsetHierarchyFilter(OffsetFilter):
'''
derivationStr = 'getElementsByOffsetInHierarchy'

def __call__(self, e, iterator):
def __call__(self, e, iterator=None):
if iterator is None:
raise TypeError('Cannot call OffsetHierarchyFilter without an iterator')

s = iterator.srcStream
if s is e:
return False
Expand Down
61 changes: 23 additions & 38 deletions music21/stream/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import copy
from typing import (TypeVar, List, Union, Callable, Optional, Literal,
TypedDict, Generic, overload, Iterable, Type, cast)
TypedDict, Generic, overload, Iterable, Type, cast,
Any, TYPE_CHECKING)
import unittest
import warnings

Expand Down Expand Up @@ -126,7 +127,7 @@ class StreamIterator(prebase.ProtoM21Object, Generic[M21ObjType]):
def __init__(self,
srcStream: StreamType,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList: Union[List[FilterType], FilterType, None] = None,
restoreActiveSites: bool = True,
activeInformation: Optional[ActiveInformation] = None,
Expand Down Expand Up @@ -697,14 +698,15 @@ def matchesFilters(self, e) -> bool:
'''
returns False if any filter returns False, True otherwise.
'''
f: Union[Callable[[Any, Optional[Any]], Any], filters.StreamFilter]
for f in self.filters:
f: Union[Callable, filters.StreamFilter]
try:
try:
if f(e, self) is False:
return False
except TypeError: # one element filters are acceptable.
f: Callable
if TYPE_CHECKING:
assert isinstance(f, filters.StreamFilter)
if f(e) is False:
return False
except StopIteration: # pylint: disable=try-except-raise
Expand All @@ -717,10 +719,6 @@ def _newBaseStream(self) -> 'music21.stream.Stream':

>>> s = stream.Stream()

So why does this exist? Since we can't import "music21.stream" here,
we will look in `srcStream.__class__.mro()` for the Stream
object to import.

This is used in places where returnStreamSubclass is False, so we
cannot just call `type(StreamIterator.srcStream)()`

Expand All @@ -729,26 +727,12 @@ def _newBaseStream(self) -> 'music21.stream.Stream':
>>> s = pi._newBaseStream()
>>> s
<music21.stream.Stream 0x1047eb2e8>

>>> pi.srcStream = note.Note()
>>> pi._newBaseStream()
Traceback (most recent call last):
music21.stream.iterator.StreamIteratorException: ...
'''
StreamBase = None
for x in self.srcStream.__class__.mro():
if x.__name__ == 'Stream':
StreamBase = x
break

try:
return StreamBase()
except TypeError: # 'NoneType' object is not callable.
raise StreamIteratorException(
f"You've given a 'stream' that is not a stream! {self.srcStream}")
from music21 import stream
return stream.Stream()

@overload
def stream(self, returnStreamSubClass: Literal[False] = True) -> 'music21.stream.Stream':
def stream(self, returnStreamSubClass: Literal[False]) -> 'music21.stream.Stream':
# ignore this code -- just here until Astroid bug #1015 is fixed
x: 'music21.stream.Stream' = self.streamObj
return x
Expand Down Expand Up @@ -843,7 +827,7 @@ def stream(self, returnStreamSubClass=True) -> Union['music21.stream.Stream', St
else:
derivationMethods = []
for f in self.filters:
if hasattr(f, 'derivationStr'):
if isinstance(f, filters.StreamFilter):
dStr = f.derivationStr
else:
dStr = f.__name__ # function; lambda returns <lambda>
Expand Down Expand Up @@ -968,7 +952,7 @@ def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> StreamIterator[ChangedM21ObjType]:
x: StreamIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
x = cast(StreamIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

@overload
Expand Down Expand Up @@ -1517,24 +1501,24 @@ class OffsetIterator(StreamIterator[M21ObjType]):
def __init__(self,
srcStream,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList=None,
restoreActiveSites=True,
activeInformation=None,
ignoreSorting=False
):
super().__init__(srcStream,
restrictClass=restrictClass,
# restrictClass=restrictClass,
filterList=filterList,
restoreActiveSites=restoreActiveSites,
activeInformation=activeInformation,
ignoreSorting=ignoreSorting,
)
self.raiseStopIterationNext = False
self.nextToYield = []
self.nextToYield: List[M21ObjType] = []
self.nextOffsetToYield = None

def __next__(self) -> List[M21ObjType]:
def __next__(self) -> List[M21ObjType]: # type: ignore
if self.raiseStopIterationNext:
raise StopIteration

Expand Down Expand Up @@ -1601,7 +1585,7 @@ def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> OffsetIterator[ChangedM21ObjType]:
x: OffsetIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
x = cast(OffsetIterator[ChangedM21ObjType], self.__class__(self.streamObj))
return x

@overload
Expand Down Expand Up @@ -1699,7 +1683,7 @@ class RecursiveIterator(StreamIterator[M21ObjType]):
def __init__(self,
srcStream,
*,
restrictClass: M21ObjType = base.Music21Object,
# restrictClass: Type[M21ObjType] = base.Music21Object,
filterList=None,
restoreActiveSites=True,
activeInformation=None,
Expand All @@ -1708,7 +1692,7 @@ def __init__(self,
ignoreSorting=False
): # , parentIterator=None):
super().__init__(srcStream,
restrictClass=restrictClass,
# restrictClass=restrictClass,
filterList=filterList,
restoreActiveSites=restoreActiveSites,
activeInformation=activeInformation,
Expand All @@ -1726,7 +1710,7 @@ def __init__(self,

if streamsOnly is True:
self.filters.append(filters.ClassFilter('Stream'))
self.childRecursiveIterator = None
self.childRecursiveIterator: Optional[RecursiveIterator[Any]] = None
# not yet used.
# self.parentIterator = None

Expand Down Expand Up @@ -1754,7 +1738,7 @@ def __next__(self) -> M21ObjType:
self.activeInformation['index'] = -1
self.activeInformation['lastYielded'] = self.srcStream
self.returnSelf = False
return self.srcStream
return cast(M21ObjType, self.srcStream)

elif self.returnSelf is True:
self.returnSelf = False
Expand All @@ -1777,7 +1761,7 @@ def __next__(self) -> M21ObjType:
# in a recursive filter, the stream does not need to match the filter,
# only the internal elements.
if e.isStream:
self.childRecursiveIterator = RecursiveIterator(
childRecursiveIterator: RecursiveIterator[M21ObjType] = RecursiveIterator(
srcStream=e,
restoreActiveSites=self.restoreActiveSites,
filterList=self.filters, # shared list...
Expand All @@ -1789,7 +1773,8 @@ def __next__(self) -> M21ObjType:
newStartOffset = (self.iteratorStartOffsetInHierarchy
+ self.srcStream.elementOffset(e))

self.childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset
childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset
self.childRecursiveIterator = childRecursiveIterator
if self.matchesFilters(e) is False:
continue

Expand Down