Skip to content

Commit

Permalink
fix #30643, correctly propagate iterator traits through Stateful (#30644
Browse files Browse the repository at this point in the history
)

(cherry picked from commit 21dfef3)
  • Loading branch information
denizyuret authored and KristofferC committed Jan 11, 2019
1 parent 8ee59bc commit 543cf24
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
5 changes: 2 additions & 3 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1094,10 +1094,9 @@ end

@inline peek(s::Stateful, sentinel=nothing) = s.nextvalstate !== nothing ? s.nextvalstate[1] : sentinel
@inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing)
IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} =
isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength()
IteratorSize(::Type{Stateful{T,VS}}) where {T,VS} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T)
eltype(::Type{Stateful{T, VS}} where VS) where {T} = eltype(T)
IteratorEltype(::Type{Stateful{VS,T}} where VS) where {T} = IteratorEltype(T)
IteratorEltype(::Type{Stateful{T,VS}}) where {T,VS} = IteratorEltype(T)
length(s::Stateful) = length(s.itr) - s.taken

end
23 changes: 23 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,26 @@ end
@test ps isa Iterators.Pairs
@test collect(ps) == [1 => :a, 2 => :b]
end

@testset "Stateful fix #30643" begin
@test Base.IteratorSize(1:10) isa Base.HasShape
a = Iterators.Stateful(1:10)
@test Base.IteratorSize(a) isa Base.HasLength
@test length(a) == 10
@test length(collect(a)) == 10
@test length(a) == 0
b = Iterators.Stateful(Iterators.take(1:10,3))
@test Base.IteratorSize(b) isa Base.HasLength
@test length(b) == 3
@test length(collect(b)) == 3
@test length(b) == 0
c = Iterators.Stateful(Iterators.countfrom(1))
@test Base.IteratorSize(c) isa Base.IsInfinite
@test length(Iterators.take(c,3)) == 3
@test length(collect(Iterators.take(c,3))) == 3
d = Iterators.Stateful(Iterators.filter(isodd,1:10))
@test Base.IteratorSize(d) isa Base.SizeUnknown
@test length(collect(Iterators.take(d,3))) == 3
@test length(collect(d)) == 2
@test length(collect(d)) == 0
end

0 comments on commit 543cf24

Please sign in to comment.