Skip to content

Commit

Permalink
Wrap usages of @ngenerate to improve type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jan 15, 2014
1 parent 714fa07 commit c6d359b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ function gen_broadcast_body(nd::Int, narrays::Int, f::Function)
@nexprs $narrays k->(@inbounds v_k = @nref $nd A_k d->j_d_k)
@inbounds (@nref $nd B i) = (@ncall $narrays $F v)
end
B
end
end

Expand Down Expand Up @@ -100,6 +99,7 @@ function broadcast!(f::Function, B, As...)
func = broadcast_cache[key]
end
func(B, As...)
B
end
end # let broadcast_cache

Expand Down
32 changes: 22 additions & 10 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

### From array.jl

@ngenerate N function checksize(A::AbstractArray, I::NTuple{N, Any}...)
@ngenerate N function _checksize(A::AbstractArray, I::NTuple{N, Any}...)
@nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("Index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))")))
nothing
end
checksize(A, I) = (_checksize(A, I); return nothing)
checksize(A, I, J) = (_checksize(A, I, J); return nothing)
checksize(A, I...) = (_checksize(A, I...); return nothing)

# Version that uses cartesian indexing for src
@ngenerate N function getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Real,AbstractVector)}...)
@ngenerate N function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Real,AbstractVector)}...)
checksize(dest, I...)
checkbounds(src, I...)
@nexprs N d->(J_d = to_index(I_d))
Expand All @@ -17,11 +20,10 @@ end
@inbounds dest[k] = (@nref N src j)
k += 1
end
dest
end

# Version that uses linear indexing for src
@ngenerate N function getindex!(dest::Array, src::Array, I::NTuple{N,Union(Real,AbstractVector)}...)
@ngenerate N function _getindex!(dest::Array, src::Array, I::NTuple{N,Union(Real,AbstractVector)}...)
checksize(dest, I...)
checkbounds(src, I...)
@nexprs N d->(J_d = to_index(I_d))
Expand All @@ -33,13 +35,18 @@ end
@inbounds dest[k] = src[offset_0]
k += 1
end
dest
end

@ngenerate N getindex(A::Array, I::NTuple{N,Union(Real,AbstractVector)}...) = getindex!(similar(A, eltype(A), index_shape(I...)), A, I...)
getindex!(dest, src, I) = (_getindex!(dest, src, I); return dest)
getindex!(dest, src, I, J) = (_getindex!(dest, src, I, J); return dest)
getindex!(dest, src, I...) = (_getindex!(dest, src, I...); return dest)

getindex(A::Array, I::Union(Real,AbstractVector)) = getindex!(similar(A, index_shape(I)), A, I)
getindex(A::Array, I::Union(Real,AbstractVector), J::Union(Real,AbstractVector)) = getindex!(similar(A, index_shape(I,J)), A, I, J)
getindex(A::Array, I::Union(Real,AbstractVector)...) = getindex!(similar(A, index_shape(I...)), A, I...)

@ngenerate N function setindex!(A::Array, x, I::NTuple{N,Union(Real,AbstractArray)}...)

@ngenerate N function _setindex!(A::Array, x, I::NTuple{N,Union(Real,AbstractArray)}...)
checkbounds(A, I...)
@nexprs N d->(J_d = to_index(I_d))
stride_1 = 1
Expand All @@ -59,9 +66,13 @@ end
k += 1
end
end
A
end

setindex!(A::Array, x, I::Union(Real,AbstractArray)) = (_setindex!(A, x, I); return A)
setindex!(A::Array, x, I::Union(Real,AbstractArray), J::Union(Real,AbstractArray)) =
(_setindex!(A, x, I, J); return A)
setindex!(A::Array, x, I::Union(Real,AbstractArray)...) = (_setindex!(A, x, I...); return A)


@ngenerate N function findn{T,N}(A::AbstractArray{T,N})
nnzA = nnz(A)
Expand Down Expand Up @@ -119,13 +130,14 @@ eval(ngenerate(:N, :(setindex!{T}(s::SubArray{T,N}, v, ind::Integer)), gen_setin

### from abstractarray.jl

@ngenerate N function fill!{T,N}(A::AbstractArray{T,N}, x)
@ngenerate N function _fill!{T,N}(A::AbstractArray{T,N}, x)
@nloops N i A begin
@inbounds (@nref N A i) = x
end
return A
end

fill!(A::AbstractArray, x) = (_fill!(A, x); return A)

## code generator for specializing on the number of dimensions ##

#otherbodies are the bodies that reside between loops, if its a 2 dimension array.
Expand Down
20 changes: 13 additions & 7 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ function reducedim!(f::Function, R, A)
func = reducedim_cache[key]
end
func(R, A)
R
end
end # let reducedim_cache

Expand Down Expand Up @@ -104,7 +105,6 @@ function gen_reduction_body(N, f::Function)
@inbounds (@nref $N R j) = ($F)((@nref $N R j), (@nref $N A i))
end
end
R
end
end

Expand All @@ -115,19 +115,25 @@ reduction_init{T}(A::AbstractArray, region, initial::T) = fill!(similar(A,T,redu
# For performance, these bypass reducedim_cache

all(A::AbstractArray{Bool}, region) = all!(reduction_init(A,region,true), A)
eval(ngenerate(:N, :(all!{N}(R::AbstractArray, A::AbstractArray{Bool,N})), N->gen_reduction_body(N, &)))
all!(R, A) = (_all!(R,A); return R)
eval(ngenerate(:N, :(_all!{N}(R::AbstractArray, A::AbstractArray{Bool,N})), N->gen_reduction_body(N, &)))
any(A::AbstractArray{Bool}, region) = any!(reduction_init(A,region,false), A)
eval(ngenerate(:N, :(any!{N}(R::AbstractArray, A::AbstractArray{Bool,N})), N->gen_reduction_body(N, |)))
any!(R, A) = (_any!(R,A); return R)
eval(ngenerate(:N, :(_any!{N}(R::AbstractArray, A::AbstractArray{Bool,N})), N->gen_reduction_body(N, |)))
maximum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : maximum!(reduction_init(A,region,typemin(T)), A)
eval(ngenerate(:N, :(maximum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, scalarmax)))
maximum!(R, A) = (_maximum!(R,A); return R)
eval(ngenerate(:N, :(_maximum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, scalarmax)))
minimum{T}(A::AbstractArray{T}, region) =
isempty(A) ? similar(A,reduced_dims0(A,region)) : minimum!(reduction_init(A,region,typemax(T)), A)
eval(ngenerate(:N, :(minimum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, scalarmin)))
minimum!(R, A) = (_minimum!(R,A); return R)
eval(ngenerate(:N, :(_minimum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, scalarmin)))
sum{T}(A::AbstractArray{T}, region) = sum!(reduction_init(A,region,zero(T)), A)
sum(A::AbstractArray{Bool}, region) = sum!(reduction_init(A,region,0), A)
eval(ngenerate(:N, :(sum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, +)))
sum!(R, A) = (_sum!(R,A); return R)
eval(ngenerate(:N, :(_sum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, +)))
prod{T}(A::AbstractArray{T}, region) = prod!(reduction_init(A,region,one(T)), A)
eval(ngenerate(:N, :(prod!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, *)))
prod!(R, A) = (_prod!(R,A); return R)
eval(ngenerate(:N, :(_prod!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, *)))

prod(A::AbstractArray{Bool}, region) = error("use all() instead of prod() for boolean arrays")

0 comments on commit c6d359b

Please sign in to comment.