diff --git a/music21/stream/base.py b/music21/stream/base.py index 11f0600c8a..074b34dff1 100644 --- a/music21/stream/base.py +++ b/music21/stream/base.py @@ -35,7 +35,7 @@ from math import isclose from typing import (Dict, Iterable, List, Optional, Set, Tuple, cast, TypeVar, Type, Union, Generic, Literal, overload, - Sequence, TYPE_CHECKING) + Sequence, TYPE_CHECKING, Any) from music21 import base @@ -2882,6 +2882,7 @@ def replaceDerived(startSite=self): target = self._endElements[i - eLen] self._endElements[i - eLen] = replacement + # noinspection PyTypeChecker self.coreSetElementOffset(replacement, OffsetSpecial.AT_END, addElement=True) replacement.sites.add(self) @@ -6501,7 +6502,7 @@ def haveAccidentalsBeenMade(self): ''' return self.streamStatus.accidentals - def makeNotation(self, + def makeNotation(self: StreamType, *, meterStream=None, refStreamOrTimeRange=None, @@ -6550,6 +6551,7 @@ def makeNotation(self, 'final' ''' # determine what is the object to work on first + returnStream: Union[StreamType, Stream[Any]] if inPlace: returnStream = self else: @@ -6578,7 +6580,7 @@ def makeNotation(self, inPlace=True, bestClef=bestClef) - measureStream = returnStream.getElementsByClass(Measure).stream() + measureStream: Stream[Measure] = returnStream.getElementsByClass(Measure).stream() # environLocal.printDebug(['Stream.makeNotation(): post makeMeasures, # length', len(returnStream)]) if not measureStream: @@ -13782,6 +13784,8 @@ def makeNotation(self, # no matter, let's just be extra cautious and run this here (Feb 2021 - JTW) returnStream.coreElementsChanged() else: # call the base method + if TYPE_CHECKING: + assert isinstance(returnStream, Score) super(Score, returnStream).makeNotation(meterStream=meterStream, refStreamOrTimeRange=refStreamOrTimeRange, inPlace=True, diff --git a/music21/stream/core.py b/music21/stream/core.py index 79d34dfa63..75f038b967 100644 --- a/music21/stream/core.py +++ b/music21/stream/core.py @@ -467,6 +467,7 @@ def coreStoreAtEnd(self, element, setActiveSite=True): Core method for adding end elements. To be called by other methods. ''' + # noinspection PyTypeChecker self.coreSetElementOffset(element, OffsetSpecial.AT_END, addElement=True) element.sites.add(self) # need to explicitly set the activeSite of the element diff --git a/music21/stream/filters.py b/music21/stream/filters.py index 9f7c8c6fc9..f3b38daab2 100644 --- a/music21/stream/filters.py +++ b/music21/stream/filters.py @@ -19,12 +19,16 @@ # import inspect import unittest from math import inf +from typing import Optional, TypeVar from music21 import common from music21.common.numberTools import opFrac from music21.exceptions21 import Music21Exception from music21 import prebase + +StreamIteratorType = TypeVar('StreamIteratorType', bound='music21.stream.iterator.StreamIterator') + class FilterException(Music21Exception): pass # ----------------------------------------------------------------------------- @@ -70,7 +74,7 @@ class StreamFilter(prebase.ProtoM21Object): # def reset(self): # pass - def __call__(self, item, iterator): + def __call__(self, item, iterator: Optional[StreamIteratorType] = None): return True class IsFilter(StreamFilter): @@ -128,7 +132,7 @@ def __init__(self, target=()): def reset(self): self.numToFind = len(self.target) - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): if self.numToFind == 0: # short circuit -- we already have raise StopIteration @@ -182,7 +186,7 @@ def __init__(self, target=()): def reset(self): pass # do nothing: inf - 1 = inf - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): return not super().__call__(item, iterator) @@ -204,7 +208,7 @@ def __init__(self, searchId=None): searchIdLower = searchId self.searchId = searchIdLower - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): if item.id == self.searchId: return True else: @@ -259,7 +263,7 @@ def __eq__(self, other): return False return True - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): return not item.classSet.isdisjoint(self.classList) def _reprInternal(self): @@ -289,7 +293,7 @@ class ClassNotFilter(ClassFilter): ''' derivationStr = 'getElementsNotOfClass' - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): return item.classSet.isdisjoint(self.classList) @@ -328,7 +332,7 @@ def __init__(self, groupFilterList=()): groupFilterList = [groupFilterList] self.groupFilterList = groupFilterList - def __call__(self, item, iterator): + def __call__(self, item, iterator=None): eGroups = item.groups for groupName in self.groupFilterList: if groupName in eGroups: @@ -383,15 +387,19 @@ def _reprInternal(self) -> str: return str(self.offsetStart) + '-' + str(self.offsetEnd) - def __call__(self, e, iterator): - s = iterator.srcStream - if s is e: - return False - offset = s.elementOffset(e) - if s.isSorted: - stopAfterEnd = self.stopAfterEnd + def __call__(self, e, iterator=None): + if iterator is None: + offset = e.offset + stopAfterEnd = False else: - stopAfterEnd = False # never stop after end on unsorted stream + s = iterator.srcStream + if s is e: + return False + offset = s.elementOffset(e) + if s.isSorted: + stopAfterEnd = self.stopAfterEnd + else: + stopAfterEnd = False # never stop after end on unsorted stream return self.isElementOffsetInRange(e, offset, stopAfterEnd=stopAfterEnd) def isElementOffsetInRange(self, e, offset, *, stopAfterEnd=False) -> bool: @@ -472,7 +480,10 @@ class OffsetHierarchyFilter(OffsetFilter): ''' derivationStr = 'getElementsByOffsetInHierarchy' - def __call__(self, e, iterator): + def __call__(self, e, iterator=None): + if iterator is None: + raise TypeError('Cannot call OffsetHierarchyFilter without an iterator') + s = iterator.srcStream if s is e: return False diff --git a/music21/stream/iterator.py b/music21/stream/iterator.py index df4f7af3c6..cd42c0f641 100644 --- a/music21/stream/iterator.py +++ b/music21/stream/iterator.py @@ -19,7 +19,8 @@ import copy from typing import (TypeVar, List, Union, Callable, Optional, Literal, - TypedDict, Generic, overload, Iterable, Type, cast) + TypedDict, Generic, overload, Iterable, Type, cast, + Any, TYPE_CHECKING) import unittest import warnings @@ -126,7 +127,7 @@ class StreamIterator(prebase.ProtoM21Object, Generic[M21ObjType]): def __init__(self, srcStream: StreamType, *, - restrictClass: M21ObjType = base.Music21Object, + # restrictClass: Type[M21ObjType] = base.Music21Object, filterList: Union[List[FilterType], FilterType, None] = None, restoreActiveSites: bool = True, activeInformation: Optional[ActiveInformation] = None, @@ -697,14 +698,15 @@ def matchesFilters(self, e) -> bool: ''' returns False if any filter returns False, True otherwise. ''' + f: Union[Callable[[Any, Optional[Any]], Any], filters.StreamFilter] for f in self.filters: - f: Union[Callable, filters.StreamFilter] try: try: if f(e, self) is False: return False except TypeError: # one element filters are acceptable. - f: Callable + if TYPE_CHECKING: + assert isinstance(f, filters.StreamFilter) if f(e) is False: return False except StopIteration: # pylint: disable=try-except-raise @@ -717,10 +719,6 @@ def _newBaseStream(self) -> 'music21.stream.Stream': >>> s = stream.Stream() - So why does this exist? Since we can't import "music21.stream" here, - we will look in `srcStream.__class__.mro()` for the Stream - object to import. - This is used in places where returnStreamSubclass is False, so we cannot just call `type(StreamIterator.srcStream)()` @@ -729,26 +727,12 @@ def _newBaseStream(self) -> 'music21.stream.Stream': >>> s = pi._newBaseStream() >>> s - - >>> pi.srcStream = note.Note() - >>> pi._newBaseStream() - Traceback (most recent call last): - music21.stream.iterator.StreamIteratorException: ... ''' - StreamBase = None - for x in self.srcStream.__class__.mro(): - if x.__name__ == 'Stream': - StreamBase = x - break - - try: - return StreamBase() - except TypeError: # 'NoneType' object is not callable. - raise StreamIteratorException( - f"You've given a 'stream' that is not a stream! {self.srcStream}") + from music21 import stream + return stream.Stream() @overload - def stream(self, returnStreamSubClass: Literal[False] = True) -> 'music21.stream.Stream': + def stream(self, returnStreamSubClass: Literal[False]) -> 'music21.stream.Stream': # ignore this code -- just here until Astroid bug #1015 is fixed x: 'music21.stream.Stream' = self.streamObj return x @@ -843,7 +827,7 @@ def stream(self, returnStreamSubClass=True) -> Union['music21.stream.Stream', St else: derivationMethods = [] for f in self.filters: - if hasattr(f, 'derivationStr'): + if isinstance(f, filters.StreamFilter): dStr = f.derivationStr else: dStr = f.__name__ # function; lambda returns @@ -968,7 +952,7 @@ def getElementsByClass(self, classFilterList: Type[ChangedM21ObjType], *, returnClone: bool = True) -> StreamIterator[ChangedM21ObjType]: - x: StreamIterator[ChangedM21ObjType] = self.__class__(self.streamObj) + x = cast(StreamIterator[ChangedM21ObjType], self.__class__(self.streamObj)) return x @overload @@ -1517,24 +1501,24 @@ class OffsetIterator(StreamIterator[M21ObjType]): def __init__(self, srcStream, *, - restrictClass: M21ObjType = base.Music21Object, + # restrictClass: Type[M21ObjType] = base.Music21Object, filterList=None, restoreActiveSites=True, activeInformation=None, ignoreSorting=False ): super().__init__(srcStream, - restrictClass=restrictClass, + # restrictClass=restrictClass, filterList=filterList, restoreActiveSites=restoreActiveSites, activeInformation=activeInformation, ignoreSorting=ignoreSorting, ) self.raiseStopIterationNext = False - self.nextToYield = [] + self.nextToYield: List[M21ObjType] = [] self.nextOffsetToYield = None - def __next__(self) -> List[M21ObjType]: + def __next__(self) -> List[M21ObjType]: # type: ignore if self.raiseStopIterationNext: raise StopIteration @@ -1601,7 +1585,7 @@ def getElementsByClass(self, classFilterList: Type[ChangedM21ObjType], *, returnClone: bool = True) -> OffsetIterator[ChangedM21ObjType]: - x: OffsetIterator[ChangedM21ObjType] = self.__class__(self.streamObj) + x = cast(OffsetIterator[ChangedM21ObjType], self.__class__(self.streamObj)) return x @overload @@ -1699,7 +1683,7 @@ class RecursiveIterator(StreamIterator[M21ObjType]): def __init__(self, srcStream, *, - restrictClass: M21ObjType = base.Music21Object, + # restrictClass: Type[M21ObjType] = base.Music21Object, filterList=None, restoreActiveSites=True, activeInformation=None, @@ -1708,7 +1692,7 @@ def __init__(self, ignoreSorting=False ): # , parentIterator=None): super().__init__(srcStream, - restrictClass=restrictClass, + # restrictClass=restrictClass, filterList=filterList, restoreActiveSites=restoreActiveSites, activeInformation=activeInformation, @@ -1726,7 +1710,7 @@ def __init__(self, if streamsOnly is True: self.filters.append(filters.ClassFilter('Stream')) - self.childRecursiveIterator = None + self.childRecursiveIterator: Optional[RecursiveIterator[Any]] = None # not yet used. # self.parentIterator = None @@ -1754,7 +1738,7 @@ def __next__(self) -> M21ObjType: self.activeInformation['index'] = -1 self.activeInformation['lastYielded'] = self.srcStream self.returnSelf = False - return self.srcStream + return cast(M21ObjType, self.srcStream) elif self.returnSelf is True: self.returnSelf = False @@ -1777,7 +1761,7 @@ def __next__(self) -> M21ObjType: # in a recursive filter, the stream does not need to match the filter, # only the internal elements. if e.isStream: - self.childRecursiveIterator = RecursiveIterator( + childRecursiveIterator: RecursiveIterator[M21ObjType] = RecursiveIterator( srcStream=e, restoreActiveSites=self.restoreActiveSites, filterList=self.filters, # shared list... @@ -1789,7 +1773,8 @@ def __next__(self) -> M21ObjType: newStartOffset = (self.iteratorStartOffsetInHierarchy + self.srcStream.elementOffset(e)) - self.childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset + childRecursiveIterator.iteratorStartOffsetInHierarchy = newStartOffset + self.childRecursiveIterator = childRecursiveIterator if self.matchesFilters(e) is False: continue