Can we get rid of auto-broadcasting of 0D arrays for activations? #608

Open
avik-pal opened this issue Sep 30, 2024 · 4 comments

Comments

@avik-pal
Member

(Ideally, I don't think it should auto-broadcast in the first place.) But if we just get rid of 0-D array broadcasting, that solves our problem over at EnzymeAD/Reactant.jl#54.

NNlib.jl/src/activations.jl, lines 752 to 755 in ba29c90:

```julia
# Define broadcasts for activation functions on arrays
for f in ACTIVATIONS
    @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
end
```
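To see what this loop generates, here is a minimal sketch with a hypothetical activation `myrelu` and a hypothetical one-element list `MY_ACTIVATIONS` standing in for NNlib's `ACTIVATIONS`. It shows that a 0-D array also takes the broadcasting method, and that ordinary broadcasting hands back a plain scalar for it:

```julia
# Hypothetical scalar activation and activation list (stand-ins for NNlib's).
myrelu(x) = max(x, zero(x))
const MY_ACTIVATIONS = (:myrelu,)

# Same pattern as the NNlib loop: each activation gets an array method that broadcasts.
for f in MY_ACTIVATIONS
    @eval $(f)(x::AbstractArray, args...) = $(f).(x, args...)
end

v = myrelu([-1.0, 2.0])  # broadcasts elementwise
s = myrelu(fill(-1.0))   # a 0-D array also takes the AbstractArray method
println(v, " ", s)       # the 0-D input comes back as a plain Float64, not a 0-D array
```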

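Essentially, Reactant needs to treat scalars as 0-D `TracedRArray`s, but that causes a recursion loop: a 0-D traced array matches the `AbstractArray` method above, and since the scalar inside the broadcast is again represented as a 0-D traced array, the call lands back on that same method. As expected, the generated IR then contains an `unreachable` (EnzymeAD/Reactant.jl#54 (comment)). This means the only way we can support NNlib activations is to manually copy over the code for every activation function.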

Now, I know there is no general way to "opt out" of broadcasting for 0-D arrays, but we could define the broadcasting methods only for N = 1..10 and hope no one is using an 11-D (or higher) tensor.
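A minimal sketch of that proposal, using a hypothetical `relu_1to10` in place of an NNlib activation: the broadcasting method is generated only for `N` in `1:10`, so a 0-D array is dispatched to the generic method instead of the broadcasting one.

```julia
# Hypothetical activation; the generic method plays the role of NNlib's scalar definition.
relu_1to10(x) = max(x, zero(x))

# Broadcasting methods only for 1- to 10-dimensional arrays.
for N in 1:10
    @eval relu_1to10(x::AbstractArray{<:Any,$N}, args...) = relu_1to10.(x, args...)
end

# Vectors still broadcast; a 0-D array now resolves to the generic relu_1to10(x).
println(relu_1to10([-1.0, 2.0]))
println(which(relu_1to10, Tuple{Array{Float64,0}}))
```

The trade-off is the one mentioned above: an array with 11 or more dimensions would also skip the broadcasting method.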

avik-pal changed the title from "Can we get rid of forwarding broadcasting of 0D arrays for activations?" to "Can we get rid of auto-broadcasting of 0D arrays for activations?" on Sep 30, 2024
@mcabbott
Member

I don't like the auto-broadcast either, but here we are.

The built-in opt-out is this function, which perhaps Reactant needs to know how to handle anyway?

```julia
julia> Base.Broadcast.broadcast_preserving_zero_d(sin, fill(pi/2))
0-dimensional Array{Float64, 0}:
1.0

julia> Base.Broadcast.broadcast_preserving_zero_d(sin, [0, pi/2])
2-element Vector{Float64}:
 0.0
 1.0
```
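The difference from ordinary broadcasting shows up on the same 0-D input. A short sketch comparing a plain dot-broadcast with `broadcast_preserving_zero_d`:

```julia
z = fill(pi / 2)   # 0-dimensional Array{Float64, 0}

plain = sin.(z)    # dot-broadcast unwraps a 0-D result into a scalar
kept = Base.Broadcast.broadcast_preserving_zero_d(sin, z)  # result stays 0-D

println(typeof(plain), " vs ", typeof(kept))
```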

@avik-pal
Member Author

This will still cause issues, right? I want the 0-D case to be forwarded to the original method without any broadcasting. For example, for `relu` I want `relu(x) = max(x, 0)` to be called instead of `relu(x::AbstractArray) = relu.(x)`.

@mcabbott
Member

mcabbott commented Sep 30, 2024

But what this function is fed is some special fake 0-D array that Reactant invents? My hope is that Reactant can also be made to understand `broadcast_preserving_zero_d(sin, x::ZeroDimTrackedArray)::ZeroDimTrackedArray`, but by short-circuiting the present implementation used for `Array{T,0}`.
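A minimal sketch of that idea, with a hypothetical `ZeroDimBox` standing in for Reactant's 0-D traced array and a hypothetical activation `myact` whose array method calls `broadcast_preserving_zero_d` instead of a plain dot-broadcast. The package-side method for `ZeroDimBox` returns the wrapper type directly rather than going through Base's generic implementation:

```julia
# Hypothetical stand-in for a 0-D traced array; here it just wraps a number.
struct ZeroDimBox{T} <: AbstractArray{T,0}
    val::T
end
Base.size(::ZeroDimBox) = ()
Base.getindex(x::ZeroDimBox) = x.val

# Hypothetical activation: generic scalar method plus an array method
# that preserves 0-D results instead of unwrapping them.
myact(x) = max(x, zero(x))
myact(x::AbstractArray, args...) = Base.Broadcast.broadcast_preserving_zero_d(myact, x, args...)

# Package-side short-circuit: for the 0-D wrapper, apply f directly and rewrap,
# bypassing the generic implementation used for Array{T,0}.
Base.Broadcast.broadcast_preserving_zero_d(f, x::ZeroDimBox) = ZeroDimBox(f(x.val))

println(myact([-1.0, 2.0]))       # ordinary vector: broadcasts elementwise
println(myact(ZeroDimBox(-1.0)).val)  # 0-D wrapper: handled by the short-circuit
```

For a plain `fill(-1.0)`, `myact` still goes through Base's generic path and returns an `Array{Float64,0}`.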

@avik-pal
Member Author

I found a solution that would do it:

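```julia
using NNlib
using Reactant: TracedRArray  # assumed location of Reactant's traced array type

# For 0-D traced arrays, bypass the AbstractArray broadcasting method and
# call the generic (scalar) method of each activation directly.
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid))
    @eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
        return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
    end
end
```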
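The same dispatch trick can be tried outside Reactant. This sketch uses a hypothetical `Traced0D` type that, like a 0-D traced array, implements the scalar operations the activation needs, and a hypothetical `relu_demo` set up the way NNlib sets up its activations. `invoke` calls the generic `relu_demo(x)` method even though more specific methods exist, so the result stays a `Traced0D`:

```julia
# Hypothetical stand-in for a 0-D traced array that supports scalar ops.
struct Traced0D{T} <: AbstractArray{T,0}
    val::T
end
Base.size(::Traced0D) = ()
Base.getindex(x::Traced0D) = x.val
Base.zero(x::Traced0D) = Traced0D(zero(x.val))
Base.max(a::Traced0D, b::Traced0D) = Traced0D(max(a.val, b.val))

# Hypothetical activation set up like NNlib: generic scalar method
# plus an auto-broadcasting AbstractArray method.
relu_demo(x) = max(x, zero(x))
relu_demo(x::AbstractArray, args...) = relu_demo.(x, args...)

# The fix: a more specific method for the 0-D type that skips broadcasting
# and runs the generic definition on the wrapper itself.
relu_demo(x::Traced0D) = invoke(relu_demo, Tuple{Any}, x)

println(relu_demo(Traced0D(-1.0)).val)
println(relu_demo([-1.0, 2.0]))  # ordinary arrays still broadcast
```

Without the `Traced0D` method, dispatch would pick the broadcasting method, which is exactly the path the issue is trying to avoid.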

> what this function is fed is some special fake 0D array which Reactant invents?

Correct. Reactant doesn't have a "Number" type, so we treat 0-D arrays as scalars.

2 participants