Skip to content

Commit

Permalink
Add shareindexes macro for faster iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Mar 4, 2016
1 parent ca6f253 commit 459bfc9
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,7 @@ export
@inbounds,
@fastmath,
@simd,
@shareindexes,
@inline,
@noinline,

Expand Down
70 changes: 70 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,73 @@ function collect{I<:IteratorND}(g::Generator{I})
dest[1] = first
return map_to!(g.f, 2, st, dest, g.iter)
end

"""
`@shareindexes` may improve the efficiency of loops that have multiple indexes.
For example,
```
@shareindexes for (IA, IB) in zip(eachindex(A), eachindex(B))
# body
end
```
generates, by default, a loop with two indexes that need to be
incremented and tested on each iteration. However, if it happens that
the two iterators in the `zip` call have the same value, then it runs
a variant of the loop that uses a single index.
"""
macro shareindexes(ex)
_shareindexes(ex)
end

function _shareindexes(ex::Expr)
if ex.head == :block
# Skip to the :for loop
i = 1
while i <= length(ex.args) && (a = ex.args[i]; !isa(a, Expr) || a.head != :for)
i += 1
end
i > length(ex.args) && error("expression must be a for loop")
ex.args[i] = _shareindexes(ex.args[i])
return ex
end
ex.head == :for || error("expression must be a for loop")
iteration, body = ex.args
indexvars, iterex = iteration.args
isa(indexvars, Symbol) && return ex # just one variable
# A couple of sanity checks
indexvars.head == :tuple || error("iteration variables must be expressed as a tuple")
iterex.head == :call && iterex.args[1] == :zip || error("iterators must be zipped")
iteratorexs = iterex.args[2:end]
n = length(indexvars.args)
length(iteratorexs) == n || error("number of indexes does not match the number of iterators")
n != 2 && return ex # just special-case 2 indexes for now
indexsyms = indexvars.args
# Evaluate the iterators
itersyms = [gensym(string("R", i)) for i = 1:n]
iterevals = [Expr(:(=), itersyms[i], iteratorexs[i]) for i = 1:n]
# Prepare a variant using a single index
index1 = gensym(:I)
body1 = itersub(body, indexsyms, index1)
esc(quote
$(iterevals...)
if $(itersyms[1]) == $(itersyms[2])
for $index1 in $(itersyms[1])
$body1
end
else
for $indexvars = zip($(itersyms...))
$body
end
end
end)
end

itersub(ex, indexsyms, replacement) = itersub!(copy(ex), indexsyms, replacement)
function itersub!(ex::Expr, indexsyms, replacement)
for i = 1:length(ex.args)
ex.args[i] = itersub!(ex.args[i], indexsyms, replacement)
end
ex
end
itersub!(sym::Symbol, indexsyms, replacement) = sym in indexsyms ? replacement : sym
itersub!(arg, indexsyms, replacement) = arg
34 changes: 34 additions & 0 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,37 @@ let f(g) = (@test size(g.iter)==(2,3))
end

@test_throws DimensionMismatch Base.IteratorND(1:2, (2,3))

# @shareindexes
let
function mydot1(A, B)
s = 0.0
for I in eachindex(A, B)
@inbounds s += A[I]*B[I]
end
s
end

function mydot2(A, B)
s = 0.0
for (IA,IB) in zip(eachindex(A), eachindex(B))
@inbounds s += A[IA]*B[IB]
end
s
end

function mydotshared(A, B)
s = 0.0
@shareindexes for (IA,IB) in zip(eachindex(A), eachindex(B))
@inbounds s += A[IA]*B[IB]
end
s
end

A = rand(3,4)
B = rand(3,4)
AS = sub(A, 1:size(A,1), :) # LinearSlow
BS = sub(B, 1:size(B,1), :)

@test mydot1(A, B) == mydot2(AS, BS) == mydotshared(A, BS)
end

0 comments on commit 459bfc9

Please sign in to comment.