Skip to content

Commit

Permalink
Generalize broadcast to handle tuples and scalars (#16986)
Browse files Browse the repository at this point in the history
* Generalized broadcast arguments

* Naming fixes

* Add some tests

* News and documentation
  • Loading branch information
pabloferz authored and stevengj committed Sep 18, 2016
1 parent a648f4a commit 2193638
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 91 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ This section lists changes that do not have deprecation warnings.
for `real(z) < 0`, which differs from `log(gamma(z))` by multiples of 2π
in the imaginary part ([#18330]).

* `broadcast` now handles tuples, and treats any argument that is not a tuple
or an array as a "scalar" ([#16986]).

Library improvements
--------------------

Expand Down Expand Up @@ -646,6 +649,7 @@ Language tooling improvements
[#16854]: https://github.com/JuliaLang/julia/issues/16854
[#16953]: https://github.com/JuliaLang/julia/issues/16953
[#16972]: https://github.com/JuliaLang/julia/issues/16972
[#16986]: https://github.com/JuliaLang/julia/issues/16986
[#17033]: https://github.com/JuliaLang/julia/issues/17033
[#17037]: https://github.com/JuliaLang/julia/issues/17037
[#17075]: https://github.com/JuliaLang/julia/issues/17075
Expand Down
4 changes: 2 additions & 2 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ promote_array_type{S<:Integer}(::typeof(.\), ::Type{S}, ::Type{Bool}, T::Type) =
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}, T::Type) = T

for f in (:+, :-, :div, :mod, :&, :|, :$)
@eval ($f){R,S}(A::AbstractArray{R}, B::AbstractArray{S}) =
_elementwise($f, promote_op($f, R, S), A, B)
@eval ($f)(A::AbstractArray, B::AbstractArray) =
_elementwise($f, promote_eltype_op($f, A, B), A, B)
end
function _elementwise(op, ::Type{Any}, A::AbstractArray, B::AbstractArray)
promote_shape(A, B) # check size compatibility
Expand Down
159 changes: 105 additions & 54 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,44 @@ export broadcast_getindex, broadcast_setindex!

## Broadcasting utilities ##

# fallback routines for broadcasting with no arguments or with scalars
# to just produce a scalar result:
# fallback for broadcasting with zero arguments and some special cases
broadcast(f) = f()
broadcast(f, x::Number...) = f(x...)
@inline broadcast(f, x::Number...) = f(x...)
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray) = fill!(X, f())
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
check_broadcast_shape(size(x), size(y))
check_broadcast_shape(broadcast_indices(x), broadcast_indices(y))
copy!(x, y)
end

## Calculate the broadcast shape of the arguments, or error if incompatible
# logic for deciding the resulting container type
containertype(x) = containertype(typeof(x))
containertype(::Type) = Any
containertype{T<:Tuple}(::Type{T}) = Tuple
containertype{T<:AbstractArray}(::Type{T}) = Array
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))

promote_containertype(::Type{Array}, ::Type{Array}) = Array
promote_containertype(::Type{Array}, ct) = Array
promote_containertype(ct, ::Type{Array}) = Array
promote_containertype(::Type{Tuple}, ::Type{Any}) = Tuple
promote_containertype(::Type{Any}, ::Type{Tuple}) = Tuple
promote_containertype{T}(::Type{T}, ::Type{T}) = T

## Calculate the broadcast indices of the arguments, or error if incompatible
# array inputs
broadcast_shape() = ()
broadcast_shape(A) = indices(A)
@inline broadcast_shape(A, B...) = broadcast_shape((), indices(A), map(indices, B)...)
broadcast_indices() = ()
broadcast_indices(A) = broadcast_indices(containertype(A), A)
broadcast_indices(::Type{Any}, A) = ()
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
broadcast_indices(::Type{Array}, A) = indices(A)
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...)
Expand All @@ -50,24 +69,21 @@ _bcsm(a, b) = a == b || length(b) == 1
_bcsm(a, b::Number) = b == 1
_bcsm(a::Number, b::Number) = a == b || b == 1

## Check that all arguments are broadcast compatible with shape
## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
check_broadcast_shape(::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, A::Union{AbstractArray,Number}) = check_broadcast_shape((), indices(A))
check_broadcast_shape(shp) = nothing
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, indices(A))
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
check_broadcast_shape(shp, ::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
function check_broadcast_shape(shp, Ashp::Tuple)
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
check_broadcast_shape(tail(shp), tail(Ashp))
end
check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A))
# comparing many inputs
@inline function check_broadcast_shape(shp, A, As...)
check_broadcast_shape(shp, A)
check_broadcast_shape(shp, As...)
@inline function check_broadcast_indices(shp, A, As...)
check_broadcast_indices(shp, A)
check_broadcast_indices(shp, As...)
end

## Indexing manipulations
Expand All @@ -83,14 +99,13 @@ end

# newindexer(shape, A) generates `keep` and `Idefault` (for use by
# `newindex` above) for a particular array `A`, given the
# broadcast_shape `shape`
# broadcast_indices `shape`
# `keep` is equivalent to map(==, indices(A), shape) (but see #17126)
newindexer(shape, x::Number) = (), ()
@inline newindexer(shape, A) = newindexer(shape, indices(A))
@inline newindexer(shape, indsA::Tuple{}) = (), ()
@inline function newindexer(shape, indsA::Tuple)
@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A))
@inline shapeindexer(shape, indsA::Tuple{}) = (), ()
@inline function shapeindexer(shape, indsA::Tuple)
ind1 = indsA[1]
keep, Idefault = newindexer(tail(shape), tail(indsA))
keep, Idefault = shapeindexer(tail(shape), tail(indsA))
(shape[1] == ind1, keep...), (first(ind1), Idefault...)
end

Expand All @@ -110,6 +125,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this
dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6))

@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
@inline _broadcast_getindex(::Type{Any}, A, I) = A
@inline _broadcast_getindex(::Any, A, I) = A[I]

## Broadcasting core
# nargs encodes the number of As arguments (which matches the number
# of keeps). The first two type parameters are to ensure specialization.
Expand All @@ -124,7 +143,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function and store the result
@inbounds B[I] = @ncall $nargs f val
end
Expand All @@ -148,7 +167,7 @@ end
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function and store the result
@inbounds C[ind] = @ncall $nargs f val
ind += 1
Expand Down Expand Up @@ -176,7 +195,7 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline function broadcast!{nargs}(f, B::AbstractArray, As::Vararg{Any,nargs})
shape = indices(B)
check_broadcast_shape(shape, As...)
check_broadcast_indices(shape, As...)
keeps, Idefaults = map_newindexer(shape, As)
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs})
B
Expand All @@ -196,7 +215,7 @@ end
# reverse-broadcast the indices
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
# extract array values
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
# call the function
V = @ncall $nargs f val
S = typeof(V)
Expand All @@ -219,7 +238,7 @@ end
end

function broadcast_t(f, ::Type{Any}, As...)
shape = broadcast_shape(As...)
shape = broadcast_indices(As...)
iter = CartesianRange(shape)
if isempty(iter)
return similar(Array{Any}, shape)
Expand All @@ -228,19 +247,46 @@ function broadcast_t(f, ::Type{Any}, As...)
keeps, Idefaults = map_newindexer(shape, As)
st = start(iter)
I, st = next(iter, st)
val = f([ As[i][newindex(I, keeps[i], Idefaults[i])] for i=1:nargs ]...)
val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...)
B = similar(Array{typeof(val)}, shape)
B[I] = val
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
end

@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_shape(As...)), As...)
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...)

@generated function broadcast_tup{AT,nargs}(f, As::AT, ::Type{Val{nargs}}, n)
quote
ntuple(n -> (@ncall $nargs f i->_broadcast_getindex(As[i], n)), Val{n})
end
end

function broadcast_c(f, ::Type{Tuple}, As...)
shape = broadcast_indices(As...)
check_broadcast_indices(shape, As...)
n = length(shape[1])
nargs = length(As)
return broadcast_tup(f, As, Val{nargs}, n)
end
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)

"""
broadcast(f, As...)
Broadcasts the arrays `As` to a common size by expanding singleton dimensions, and returns
an array of the results `f(as...)` for each position.
Broadcasts the arrays, tuples and/or scalars `As` to a container of the
appropriate type and dimensions. In this context, anything that is not a
subtype of `AbstractArray` or `Tuple` is considered a scalar. The resulting
container is stablished by the following rules:
- If all the arguments are scalars, it returns a scalar.
- If the arguments are tuples and zero or more scalars, it returns a tuple.
- If there is at least an array in the arguments, it returns an array
(and treats tuples as 1-dimensional arrays) expanding singleton dimensions.
A special syntax exists for broadcasting: `f.(args...)` is equivalent to
`broadcast(f, args...)`, and nested `f.(g.(args...))` calls are fused into a
single broadcast loop.
```jldoctest
julia> A = [1, 2, 3, 4, 5]
Expand All @@ -266,27 +312,32 @@ julia> broadcast(+, A, B)
8 9
11 12
14 15
```
"""
@inline broadcast(f, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
# alternate, more compact implementation; unfortunately slower.
# also the `collect` machinery doesn't yet support arbitrary index bases.
#=
@generated function _broadcast{nargs}(f, keeps, As, ::Type{Val{nargs}}, iter)
quote
collect((@ncall $nargs f i->As[i][newindex(I, keeps[i])]) for I in iter)
end
end
julia> parse.(Int, ["1", "2"])
2-element Array{Int64,1}:
1
2
function broadcast(f, As...)
shape = broadcast_shape(As...)
iter = CartesianRange(shape)
keeps, Idefaults = map_newindexer(shape, As)
naT = Val{nfields(As)}
_broadcast(f, keeps, Idefaults, As, naT, iter)
end
=#
julia> abs.((1, -2))
(1,2)
julia> broadcast(+, 1.0, (0, -2.0))
(1.0,-1.0)
julia> broadcast(+, 1.0, (0, -2.0), [1])
2-element Array{Float64,1}:
2.0
0.0
julia> string.(("one","two","three","four"), ": ", 1:4)
4-element Array{String,1}:
"one: 1"
"two: 2"
"three: 3"
"four: 4"
```
"""
@inline broadcast(f, As...) = broadcast_c(f, containertype(As...), As...)

"""
bitbroadcast(f, As...)
Expand All @@ -304,7 +355,7 @@ julia> bitbroadcast(isodd,[1,2,3,4,5])
true
```
"""
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_shape(As...)), As...)
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_indices(As...)), As...)

"""
broadcast_getindex(A, inds...)
Expand Down Expand Up @@ -345,13 +396,13 @@ julia> broadcast_getindex(C,[1,2,10])
15
```
"""
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_shape(I...)), src, I...)
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_indices(I...)), src, I...)
@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_shape(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
check_broadcast_indices(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == OneTo(1)))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin
Expand All @@ -374,7 +425,7 @@ position in `X` at the indices in `A` given by the same positions in `inds`.
quote
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = broadcast_shape($(Isplat...))
shape = broadcast_indices($(Isplat...))
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d])
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == 1:1))
if !isa(x, AbstractArray)
Expand Down
1 change: 1 addition & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ maybe_oneto() = OneTo(1)

### From abstractarray.jl: Internal multidimensional indexing definitions ###
getindex(x::Number, i::CartesianIndex{0}) = x
getindex(t::Tuple, I...) = getindex(t, IteratorsMD.flatten(I)...)

# These are not defined on directly on getindex to avoid
# ambiguities for AbstractArray subtypes. See the note in abstractarray.jl
Expand Down
2 changes: 2 additions & 0 deletions base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ zero{T<:Number}(::Type{T}) = convert(T,0)
one(x::Number) = oftype(x,1)
one{T<:Number}(::Type{T}) = convert(T,1)

_default_type(::Type{Number}) = Int

"""
factorial(n)
Expand Down
4 changes: 2 additions & 2 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ function _mapreducedim!{T,N}(f, op, R::AbstractArray, A::AbstractArray{T,N})
return R
end
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(R)) # handle d=1 manually
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
if reducedim1(R, A)
# keep the accumulator as a local variable when reducing along the first dimension
i1 = first(indices1(R))
Expand Down Expand Up @@ -331,7 +331,7 @@ function findminmax!{T,N}(f, Rval, Rind, A::AbstractArray{T,N})
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(Rval))
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
k = 0
if reducedim1(Rval, A)
i1 = first(indices1(Rval))
Expand Down
2 changes: 1 addition & 1 deletion base/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
triu, vec, permute!

import Base.Broadcast: broadcast_shape
import Base.Broadcast: broadcast_indices

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blkdiag, dense, droptol!, dropzeros!, dropzeros,
Expand Down
Loading

2 comments on commit 2193638

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @jrevels

Please sign in to comment.