collect doesn't preserve shape on Broadcasted objects #43847

Closed
oxinabox opened this issue Jan 17, 2022 · 5 comments
Labels
broadcast (Applying a function over a collection), collections (Data structures holding multiple items, e.g. sets)

@oxinabox
Contributor

Consider a 2D generator:

julia> g = (1 for x in 1:2, y in 2:3)
Base.Generator{Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, var"#47#48"}(var"#47#48"(), Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}((1:2, 2:3)))

julia> size(g)
(2, 2)

julia> collect(g)
2×2 Matrix{Int64}:
 1  1
 1  1

All is well.

But for a Broadcasted object the shape is not preserved, and collect returns a flat Vector:

julia> h = Base.broadcasted(sqrt, [1 2; 3 4])
Base.Broadcast.Broadcasted(sqrt, ([1 2; 3 4],))

julia> size(h)
(2, 2)

julia> collect(h)
4-element Vector{Float64}:
 1.0
 1.7320508075688772
 1.4142135623730951
 2.0

This can be worked around by materializing first:

julia> collect(Base.materialize(h))
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0

But why does it not work in the first place?
I suspect it is because the IteratorSize trait isn't set to HasShape.
But why isn't it?

julia> Base.IteratorSize(g)
Base.HasShape{2}()

julia> Base.IteratorSize(h)
Base.HasLength()
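To check that the trait alone decides the output shape, here is a minimal sketch using a hypothetical iterator type, OnesGrid (not part of Base). Both variants define the same size, length and eltype and differ only in the IteratorSize they declare; collect should give a Matrix for HasShape{2} and a flat Vector for HasLength, which matches what we see for g and h above.

```julia
# Hypothetical iterator (not part of Base) that yields the constant 1
# m*n times, mirroring the generator `(1 for x in 1:m, y in 1:n)`.
# The type parameter `Shaped` selects which IteratorSize trait it declares.
struct OnesGrid{Shaped}
    m::Int
    n::Int
end

Base.IteratorSize(::Type{OnesGrid{true}}) = Base.HasShape{2}()
Base.IteratorSize(::Type{OnesGrid{false}}) = Base.HasLength()

# Both variants report the same size, length and element type.
Base.size(g::OnesGrid) = (g.m, g.n)
Base.length(g::OnesGrid) = g.m * g.n
Base.eltype(::Type{<:OnesGrid}) = Int

# The iteration state is a linear counter over the m*n elements.
function Base.iterate(g::OnesGrid, k::Int = 1)
    k > length(g) && return nothing
    return (1, k + 1)
end

shaped = collect(OnesGrid{true}(2, 2))   # a 2×2 Matrix{Int}
flat   = collect(OnesGrid{false}(2, 2))  # a 4-element Vector{Int}, despite size (2, 2)
```

The HasLength variant still answers size correctly, exactly like h does; collect simply never asks for it.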

It seems to just be hitting the default fallback definition of IteratorSize
rather than the existing method:

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()

We do have:

julia> typeof(h)
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(sqrt), Tuple{Matrix{Int64}}}

And that Base.Broadcast.DefaultArrayStyle{2} type parameter tells us what we need to know to declare HasShape{2}.
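A small sketch of that point, assuming the field layout of Broadcasted in current Base (an axes field that stays nothing until instantiate). style_ndims is a hypothetical helper, not a Base function: it reads N from the style type parameter, which is available before instantiate, whereas the axes-based method cannot match until the axes have been computed.

```julia
A = [1 2; 3 4]
h = Base.broadcasted(sqrt, A)
h2 = Broadcast.instantiate(h)

# Before instantiate the axes field is `nothing`, so a method restricted to
# Broadcasted{<:Any,<:NTuple{N,Base.OneTo}} cannot match h.
h.axes === nothing                          # true
h2.axes == (Base.OneTo(2), Base.OneTo(2))   # true

# Hypothetical helper: read N straight off the style type parameter.
style_ndims(::Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{N}}) where {N} = N
style_ndims(h)    # 2, even though h is not instantiated
```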

It seems like we could define

IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}}}) where {N} = HasShape{N}()

and indeed that does seem to work

julia> collect(h)
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0

julia> Base.IteratorSize(::Type{<:Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{N}}}) where {N} = Base.HasShape{N}()

julia> collect(h)
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
@tkf
Member

tkf commented Jan 18, 2022

You forgot to call instantiate:

julia> h = Base.broadcasted(sqrt, [1 2; 3 4])
Base.Broadcast.Broadcasted(sqrt, ([1 2; 3 4],))

julia> h2 = Broadcast.instantiate(h)
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}}(sqrt, ([1 2; 3 4],))

julia> collect(h2)
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
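Until the trait is fixed, one version-independent workaround is to instantiate before collecting. collect_shaped below is a hypothetical helper, not a Base function; the sketch assumes instantiating an already instantiated Broadcasted is harmless (it only rechecks the axes). The last line shows that the uninstantiated path loses only the shape: the elements come out in the same column-major order.

```julia
A = [1 2; 3 4]
h = Base.broadcasted(sqrt, A)   # lazy, not yet instantiated

# Hypothetical helper: compute the axes first, then collect, so the
# HasShape method for instantiated Broadcasted objects applies.
collect_shaped(bc::Broadcast.Broadcasted) = collect(Broadcast.instantiate(bc))

m = collect_shaped(h)
size(m)                               # (2, 2)
m == Base.Broadcast.materialize(h)    # true

# Whatever shape collect(h) has on a given Julia version, the elements
# are produced in the same column-major order.
vec(collect(h)) == vec(sqrt.(A))      # true
```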

@oxinabox
Contributor Author

oxinabox commented Jan 18, 2022

How so?

Without instantiate it still defines size and is iterable.
We should honor its size.

@vtjnash
Member

vtjnash commented Jan 18, 2022

Agreed. In particular, instantiate is not supposed to have visible side-effects in this way, but merely pre-computes useful information. Can we entirely replace that Base.IteratorSize definition? It seems like it may otherwise be ambiguous (in some cases) with the proposed fix here?

@oxinabox
Contributor Author

I am not sure we can replace it, because it might be hit by a package that otherwise overloads the broadcast style.
So we may need to include the disambiguating method.
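A minimal sketch of the ambiguity and of a disambiguating method, using hypothetical toy types instead of the real Broadcasted so nothing in Base is overwritten: Style{N} stands in for AbstractArrayStyle{N} and Lazy{S,Ax} for Broadcasted{S,Ax}. One method keys on computed axes, the other on the style; for an object that has both, neither method is more specific, so the call is ambiguous until a method covering the intersection is added.

```julia
# Hypothetical toy types (not Base): Style{N} stands in for
# AbstractArrayStyle{N}; Lazy{S,Ax} stands in for Broadcasted{S,Ax}, with
# Ax === Nothing meaning "axes not computed yet".
struct Style{N} end
struct Lazy{S,Ax} end

# Keyed on computed axes, like the existing Base method.
shape_source(::Lazy{<:Any,<:Tuple}) = :axes
# Keyed on the style, like the proposed fix.
shape_source(::Lazy{<:Style}) = :style

x = Lazy{Style{2},Tuple{Int,Int}}()   # has both a style and computed axes

# Both methods apply and neither is more specific, so Julia throws a
# MethodError reporting the ambiguity.
ambiguous_before = try
    shape_source(x)
    false
catch err
    err isa MethodError
end

# Disambiguating method covering the intersection of the two signatures.
shape_source(::Lazy{<:Style,<:Tuple}) = :both

shape_source(x)                          # :both
shape_source(Lazy{Style{2},Nothing}())   # :style
shape_source(Lazy{Nothing,Tuple{Int}}()) # :axes
```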

@oxinabox
Contributor Author

oxinabox commented Jun 6, 2022

Effectively closed by #44061,
though some issues with that led to #45477 being opened.

oxinabox closed this as completed Jun 6, 2022
nsajko added the broadcast and collections labels Jan 8, 2025
4 participants