-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Preserve shape when collecting broadcasted objects #44061
Preserve shape when collecting broadcasted objects #44061
Conversation
95a5deb
to
faa8586
Compare
For me BTW, not all |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't think of any reason not to do this. @mbauman you have any objects?
Thank you @N5N3, you suggestion seems to work great! Something to note, this does broaden the definition of
effectively becomes
I don't believe it's a problem but thought it was worth noting. |
I tested locally with nest bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)));
bc = Base.broadcasted(+, bc , bc);
bc = Base.broadcasted(+, bc , bc);
@inferred(Base.IteratorSize(bc)) # error on 9c82a3a The easist solution is adding Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) # _fieldtypes is unneeded anymore I'm not sure is Also for consistency, it would be good to add |
base/broadcast.jl
Outdated
N isa Integer && return N | ||
_maxndims(fieldtype(BC, 2)) | ||
end | ||
Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See Iterators.zip_iteratorsize
for and example of how this is normally implemented
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like nested zip
also has similar problem? on master:
julia> a = Iterators.zip(1:10,1:10)
zip(1:10, 1:10)
julia> b = zip(a, a);
julia> c = zip(b, b);
julia> @code_warntype Base.IteratorSize(c)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}})
from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
#self#::Core.Const(Base.IteratorSize)
x::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}}
Body::Any
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}, Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}}})
│ %2 = Base.IteratorSize(%1)::Any
└── return %2
julia> @code_warntype Base.IteratorSize(b)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}})
from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
#self#::Core.Const(Base.IteratorSize)
x::Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}
Body::Base.HasShape{1}
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}})
│ %2 = Base.IteratorSize(%1)::Core.Const(Base.HasShape{1}())
└── return %2
julia> @code_warntype Base.IteratorSize(a)
MethodInstance for Base.IteratorSize(::Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
from Base.IteratorSize(x) in Base at generator.jl:92
Arguments
#self#::Core.Const(Base.IteratorSize)
x::Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}}
Body::Base.HasShape{1}
1 ─ %1 = Base.typeof(x)::Core.Const(Base.Iterators.Zip{Tuple{UnitRange{Int64}, UnitRange{Int64}}})
│ %2 = Base.IteratorSize(%1)::Core.Const(Base.HasShape{1}())
└── return %2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've tried implementing this as in Iterators.zip_iteratorsize
but can't get nested broadcasts to pass the @inferred
test. Is there a reason @pure
shouldn't be used here @vtjnash ? I'm not familiar with when it's safe or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pure
should not be used. It is never safe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed @pure
and instead defined methods of _maxndims
for small tuples which has helped the inference on nested broadcasts.
The inference through nested broadcasts won't work for more complex cases e.g. where the original broadcasted
and the nested broadcast have >2 args. For example a test like this would fail:
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)), AD1(randn(3)))
bc_nest = Base.broadcasted(+, bc , bc, bc)
@test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}()
My thinking was that perhaps we could live with more complex cases like this being uninferrable, so long as the simpler cases can be inferred.
bumping this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @BSnelling — I'm so sorry this languished. I think this is now a good workaround and can be merged exactly as it stands.
BUT: I think we can go one better by greedily doing a Broadcast.instantiate
on user-constructed broadcasts. That'll take some thinking — it's not done for performance to prevent recursively spending effort constructing axes on inner (fused) broadcasts — but I think when a user constructs a broadcast manually they'll want the axes constructed. That'll take some more doing as it means splitting the internal API from the external one.
So in the meantime (and to ensure this works now), let's re-run CI here (it's been a few months) and get this in.
6a90f31
to
70fc3cd
Compare
Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}() | ||
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() | ||
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2)) | ||
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this line will never be hitted.
So even AbstractArrayStyle
with dimension tracking now use the general fallback above.
Thus we have
julia> Base.broadcasted(randn,) |> collect
ERROR: MethodError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: JuliaLang#44061 (comment) Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: JuliaLang#44061 (comment) Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: JuliaLang#44061 (comment) Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: JuliaLang#44061 (comment) Follow up on JuliaLang#44061. Also xref JuliaLang#45477.
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: #44061 (comment) Follow up on #44061. Also xref #45477.
The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: #44061 (comment) Follow up on #44061. Also xref #45477. (cherry picked from commit d3964b6)
…aLang#56999) The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: JuliaLang#44061 (comment) Follow up on JuliaLang#44061. Also xref JuliaLang#45477. (cherry picked from commit d3964b6)
Fix for #43847
I first implemented the fix proposed on the issue but as expected this was ambiguous. I'm not sure if my proposal is as general as the initial proposal but it is not ambiguous and results in desired behaviour in a test.
(Replacing #44039)