diff --git a/src/vectorizationbase_compat/contract_pass.jl b/src/vectorizationbase_compat/contract_pass.jl index fd11ffbf..787cc3d7 100644 --- a/src/vectorizationbase_compat/contract_pass.jl +++ b/src/vectorizationbase_compat/contract_pass.jl @@ -11,42 +11,32 @@ function mulexprcost(@nospecialize(x::ProdArg))::Int return base + length(ex.args) end end -function mul_fast_expr(args::SubArray{Any, 1, Vector{Any}, Tuple{UnitRange{Int64}}, true})::Expr +function mul_fast_expr( + args::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int64}},true} +)::Expr b = Expr(:call, :mul_fast) for i ∈ 2:length(args) push!(b.args, args[i]) end b end -function mulexpr(mulexargs::SubArray{Any, 1, Vector{Any}, Tuple{UnitRange{Int64}}, true})::Tuple{ProdArg,ProdArg} +function mulexpr( + mulexargs::SubArray{Any,1,Vector{Any},Tuple{UnitRange{Int64}},true} +)::Tuple{ProdArg,ProdArg} a = (mulexargs[1])::ProdArg - if length(mulexargs) == 2 - return (a, mulexargs[2]::ProdArg) - elseif length(mulexargs) == 3 - # We'll calc the product between the guesstimated cheaper two args first, for better out of order execution - b = (mulexargs[2])::ProdArg - c = (mulexargs[3])::ProdArg - ac = mulexprcost(a) - bc = mulexprcost(b) - cc = mulexprcost(c) - maxc = max(ac, bc, cc) - if ac == maxc - return (a, Expr(:call, :mul_fast, b, c)) - elseif bc == maxc - return (b, Expr(:call, :mul_fast, a, c)) - else - return (c, Expr(:call, :mul_fast, a, b)) - end - else - return (a, mul_fast_expr(mulexargs)) - end - a = (mulexargs[1])::Union{Symbol,Expr,Number} - b = if length(mulexargs) == 2 # two arg mul - (mulexargs[2])::Union{Symbol,Expr,Number} - else - mul_fast_expr(mulexargs) - end - a, b + Nexpr = length(mulexargs) + Nexpr == 2 && (a, mulexargs[2]::ProdArg) + Nexpr != 3 && (a, mul_fast_expr(mulexargs)) + # We'll calc the product between the guesstimated cheaper two args first, for better out of order execution + b = (mulexargs[2])::ProdArg + c = (mulexargs[3])::ProdArg + ac = mulexprcost(a) + bc = mulexprcost(b) + cc = mulexprcost(c) + maxc = max(ac, bc, cc) + ac == maxc && return (a, Expr(:call, :mul_fast, b, c)) + bc == maxc && return (b, Expr(:call, :mul_fast, c, a)) + return (c, Expr(:call, :mul_fast, a, b)) end function append_args_skip!(call, args, i, mod) for j ∈ eachindex(args) @@ -228,7 +218,8 @@ function capture_a_muladd(ex::Expr, mod) end true, call end -capture_muladd(ex::Expr, mod) = while true +capture_muladd(ex::Expr, mod) = + while true ex.head === :ref && return ex if Meta.isexpr(ex, :call, 2) if (ex.args[1] === :(-))