Preserve shape when collecting broadcasted objects #44061

Merged
merged 6 commits into from
May 27, 2022


BSnelling
Contributor

Fix for #43847

I first implemented the fix proposed on the issue but as expected this was ambiguous. I'm not sure if my proposal is as general as the initial proposal but it is not ambiguous and results in desired behaviour in a test.

(Replacing #44039)

@N5N3
Member

N5N3 commented Feb 8, 2022

For me, collect is not the best way to test a Broadcasted's shape during iteration,
since copy(bc) should be the official way to "collect" a Broadcasted.
Something like collect(Iterators.product(bc, bc)) makes more sense.

BTW, not all AbstractArrayStyles track bc's dimensionality (e.g. Broadcast.ArrayStyle).
I guess we will get an error for these kinds of styles when testing collect(Iterators.product(bc, bc)).
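
As a rough illustration of the distinction drawn above (a sketch, not code from this PR): `Broadcast.materialize` is the documented way to turn a lazy `Broadcasted` into an array, while `collect` goes through the iterator interface and so depends on `Base.IteratorSize`. The example instantiates the broadcast first so that the axes are known; the input values are made up for illustration.

```julia
# Build a lazy (uninstantiated) broadcast of a 2×2 matrix plus a scalar.
bc = Base.broadcasted(+, [1 2; 3 4], 1)

# Materializing is the documented route from a Broadcasted to an array.
materialized = Broadcast.materialize(bc)

# Instantiating computes the axes, so iteration knows the shape.
ibc = Broadcast.instantiate(bc)
collected = collect(ibc)
```

Both results should agree here; the PR is about making `collect` keep the shape even when the broadcast has not been instantiated.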

@vtjnash vtjnash requested a review from mbauman February 10, 2022 21:27
@vtjnash vtjnash added the "merge me" label (PR is reviewed. Merge when all tests are passing) Feb 14, 2022
Member

@vtjnash vtjnash left a comment

I can't think of any reason not to do this. @mbauman, do you have any objections?

@DilumAluthge DilumAluthge removed the "merge me" label (PR is reviewed. Merge when all tests are passing) Feb 17, 2022
@DilumAluthge
Member

Removing the merge me label until:

  1. @N5N3 finishes their review
  2. @mbauman weighs in

@BSnelling
Contributor Author

BSnelling commented Feb 17, 2022

Thank you @N5N3, your suggestion seems to work great!

Something to note, this does broaden the definition of IteratorSize that was here originally:

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()

effectively becomes

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = Base.HasShape{N}()

I don't believe it's a problem but thought it was worth noting.
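
To make the broadening concrete, here is a toy sketch (the `ToyBC` struct and the `narrow_trait`/`wide_trait` functions are hypothetical stand-ins, not part of Base). It shows that a tuple of non-`OneTo` axes matches the `NTuple{N,Any}` signature but not the `NTuple{N,Base.OneTo}` one.

```julia
# Hypothetical stand-in for a Broadcasted whose second parameter is the axes type.
struct ToyBC{Axes}
    axes::Axes
end

# Narrow rule: only tuples of Base.OneTo axes report a shape.
narrow_trait(::Type) = Base.HasLength()
narrow_trait(::Type{<:ToyBC{<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()

# Broadened rule: any N-tuple of axes reports a shape.
wide_trait(::Type) = Base.HasLength()
wide_trait(::Type{<:ToyBC{<:NTuple{N,Any}}}) where {N} = Base.HasShape{N}()

oneto_axes  = ToyBC((Base.OneTo(3), Base.OneTo(2)))
offset_axes = ToyBC((2:4, 0:1))   # e.g. axes that do not start at 1
```

With these definitions, `offset_axes` only gets `HasShape{2}()` from the broadened rule, which is the kind of case the wider signature newly covers.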

@N5N3
Member

N5N3 commented Feb 18, 2022

I tested locally with a nested Broadcasted{<:AbstractArrayStyle{Any}}, e.g.

bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)));
bc = Base.broadcasted(+, bc , bc);
bc = Base.broadcasted(+, bc , bc);
@inferred(Base.IteratorSize(bc)) # error on 9c82a3a

The easiest solution is adding Base.@pure.

Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) # _fieldtypes is no longer needed

I'm not sure whether @pure is OK here, as we'd better avoid using it whenever possible.
But the recursion seems unavoidable.

Also, for consistency, it would be good to add Base.ndims(bc::Broadcasted) = ndims(typeof(bc)), as ndims should be defined for all Type{<:Broadcasted} after 9c82a3a.
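
A self-contained variant of the check above, using plain `Vector`s instead of the `AD1` test helper (which lives in Julia's test suite and is not defined here). This is only a sketch with a single level of nesting; the trait is expected to be `HasShape{1}()` on Julia versions that include this PR, but the exact result depends on the version.

```julia
using Test

v = randn(3)

# An uninstantiated broadcast, then a broadcast nested over it.
bc = Base.broadcasted(+, v, v)
bc_nest = Base.broadcasted(+, bc, bc)

# @inferred throws if the return type of the call cannot be inferred concretely.
trait = @inferred Base.IteratorSize(bc_nest)
```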

N isa Integer && return N
_maxndims(fieldtype(BC, 2))
end
Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T))
Member

See Iterators.zip_iteratorsize for an example of how this is normally implemented.

Member

@N5N3 N5N3 Feb 23, 2022

Looks like nested zip also has a similar problem? On master:

julia> a = Iterators.zip(1:10,1:10)
zip(1:10, 1:10)

julia> b = zip(a, a);

julia> c = zip(b, b);

julia> @code_warntype Base.IteratorSize(c)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}})
  from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
  #self#::Core.Const(Base.IteratorSize)
  x::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}}
Body::Any
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}})
│   %2 = Base.IteratorSize(%1)::Any
└──      return %2


julia> @code_warntype Base.IteratorSize(b)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}})
  from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
  #self#::Core.Const(Base.IteratorSize)
  x::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}
Body::Base.HasShape{1}
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}})
│   %2 = Base.IteratorSize(%1)::Core.Const(Base.HasShape{1}())
└──      return %2


julia> @code_warntype Base.IteratorSize(a)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
  from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
  #self#::Core.Const(Base.IteratorSize)
  x::Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}
Body::Base.HasShape{1}
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
│   %2 = Base.IteratorSize(%1)::Core.Const(Base.HasShape{1}())
└──      return %2

Contributor Author

I've tried implementing this as in Iterators.zip_iteratorsize but can't get nested broadcasts to pass the @inferred test. Is there a reason @pure shouldn't be used here @vtjnash ? I'm not familiar with when it's safe or not.

Member

@pure should not be used. It is never safe.

Contributor Author

I've removed @pure and instead defined methods of _maxndims for small tuples which has helped the inference on nested broadcasts.

The inference through nested broadcasts won't work for more complex cases e.g. where the original broadcasted and the nested broadcast have >2 args. For example a test like this would fail:

bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)), AD1(randn(3)))
bc_nest = Base.broadcasted(+, bc , bc, bc)
@test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}()

My thinking was that perhaps we could live with more complex cases like this being uninferrable, so long as the simpler cases can be inferred.
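
The small-tuple approach described above can be sketched with toy functions (the names `toy_ndims` and `toy_maxndims` are hypothetical; this is not the code merged into Base). Dedicated one- and two-element methods use plain `max` calls that the compiler can see through easily, while longer tuples fall back to a generic fold.

```julia
# Dimensionality of an argument type: N for arrays, 0 for anything else.
toy_ndims(::Type) = 0
toy_ndims(::Type{<:AbstractArray{<:Any,N}}) where {N} = N

# Generic fallback: fold `max` over all field types of the tuple type.
toy_maxndims(::Type{T}) where {T<:Tuple} = mapfoldl(toy_ndims, max, fieldtypes(T))

# Specialised methods for small tuples, which avoid the fold entirely.
toy_maxndims(::Type{Tuple{T}}) where {T} = toy_ndims(T)
toy_maxndims(::Type{Tuple{T,S}}) where {T,S} = max(toy_ndims(T), toy_ndims(S))

two_args   = toy_maxndims(Tuple{Vector{Int},Matrix{Float64}})
three_args = toy_maxndims(Tuple{Vector{Int},Int,Array{Int,3}})
```

The three-argument call still returns the right answer through the fallback; the trade-off discussed above is only about whether inference can prove it.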

@vtjnash vtjnash requested review from vtjnash and removed request for vtjnash February 22, 2022 17:46
@oxinabox
Contributor

bumping this.

Member

@mbauman mbauman left a comment

Hey @BSnelling — I'm so sorry this languished. I think this is now a good workaround and can be merged exactly as it stands.

BUT: I think we can go one better by greedily doing a Broadcast.instantiate on user-constructed broadcasts. That'll take some thinking — it's not done for performance to prevent recursively spending effort constructing axes on inner (fused) broadcasts — but I think when a user constructs a broadcast manually they'll want the axes constructed. That'll take some more doing as it means splitting the internal API from the external one.

So in the meantime (and to ensure this works now), let's re-run CI here (it's been a few months) and get this in.
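
For readers unfamiliar with the distinction, a minimal sketch of what `Broadcast.instantiate` changes on a manually constructed broadcast (the input values are arbitrary):

```julia
# A manually constructed broadcast starts without axes.
bc = Base.broadcasted(+, [1, 2, 3], 10)
axes_before = bc.axes                  # `nothing` until instantiated

# instantiate computes and stores the broadcast axes.
ibc = Broadcast.instantiate(bc)
axes_after = axes(ibc)
```

Once the axes are stored, the shape is available from the type's axes parameter without inspecting the arguments.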

@mbauman mbauman force-pushed the bes/collect_broadcasted_2 branch from 6a90f31 to 70fc3cd Compare May 26, 2022 17:14
@mbauman mbauman added the "merge me" label (PR is reviewed. Merge when all tests are passing) May 26, 2022
@vtjnash vtjnash merged commit 938da26 into JuliaLang:master May 27, 2022
Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
Member

@N5N3 N5N3 May 27, 2022

Looks like this line will never be hit,
so even an AbstractArrayStyle with dimension tracking now uses the general fallback above.
Thus we have

julia> Base.broadcasted(randn,) |> collect
ERROR: MethodError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
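
The error above looks like the generic empty-reduction failure: with zero arguments there is nothing to fold `max` over. A toy reproduction, independent of `Broadcasted` (the `init=0` choice is an assumption for illustration, not the fix adopted upstream):

```julia
# Folding `max` over an empty tuple has no starting value, so it throws.
threw = try
    mapfoldl(identity, max, ())
    false
catch
    true
end

# Supplying `init` gives the fold a value to return for empty input.
with_init = mapfoldl(identity, max, (); init=0)
```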

@giordano giordano removed the "merge me" label (PR is reviewed. Merge when all tests are passing) Jun 10, 2022
nsajko added a commit to nsajko/julia that referenced this pull request Jan 8, 2025
The `N<:Integer` constraint was nonsensical, given that
`(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022:
JuliaLang#44061 (comment)

Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
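
A toy sketch of why the constraint could never match (the `ToyStyle` type and `describe` function are hypothetical, not from Base): a `where {N<:Integer}` bound is a subtype constraint, so it only admits types, whereas a dimensionality parameter such as `1` is an `Int` value.

```julia
# Hypothetical style type with a dimensionality parameter.
struct ToyStyle{N} end

# The bound `N<:Integer` requires N to be a *type* that subtypes Integer.
describe(::Type{ToyStyle{N}}) where {N<:Integer} = "constrained"
# Unconstrained fallback, matching any parameter, including values like 1.
describe(::Type{ToyStyle{N}}) where {N} = "unconstrained"

value_param = describe(ToyStyle{1})     # N is the Int value 1
type_param  = describe(ToyStyle{Int})   # N is the type Int
```

Only `ToyStyle{Int}` reaches the constrained method, which mirrors why the `N<:Integer` method in this PR was never selected for styles such as `DefaultArrayStyle{1}`.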
@nsajko nsajko added the "broadcast" label (Applying a function over a collection) Jan 8, 2025
nsajko added a commit to nsajko/julia that referenced this pull request Jan 9, 2025
The `N<:Integer` constraint was nonsensical, given that
`(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022:
JuliaLang#44061 (comment)

Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
N5N3 pushed a commit that referenced this pull request Jan 10, 2025
The `N<:Integer` constraint was nonsensical, given that `(N === Any) ||
(N isa Int)`. N5N3 noticed this back in 2022:
#44061 (comment)

Follow up on #44061. Also xref #45477.
KristofferC pushed a commit that referenced this pull request Jan 13, 2025
The `N<:Integer` constraint was nonsensical, given that `(N === Any) ||
(N isa Int)`. N5N3 noticed this back in 2022:
#44061 (comment)

Follow up on #44061. Also xref #45477.

(cherry picked from commit d3964b6)
nsajko added a commit to nsajko/julia that referenced this pull request Feb 9, 2025
…aLang#56999)

The `N<:Integer` constraint was nonsensical, given that `(N === Any) ||
(N isa Int)`. N5N3 noticed this back in 2022:
JuliaLang#44061 (comment)

Follow up on JuliaLang#44061. Also xref JuliaLang#45477.

(cherry picked from commit d3964b6)
Labels
broadcast Applying a function over a collection
8 participants