fix!: remove dropout branching based on size
avik-pal committed Aug 29, 2024
1 parent 6e1553e commit 38f9941
Showing 2 changed files with 10 additions and 43 deletions.
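The breaking change: when a preset `mask` is supplied with `update_mask = Val(false)`, `dropout` no longer falls back to regenerating the mask when its size disagrees with `dropout_shape(x, dims)`; the mismatch is now an error. Branching on a runtime size meant the gradient program depended on runtime values, which is why the test suite below had to mark `@inferred(Zygote.gradient(...))` as `broken=true` and enable Enzyme's runtime-activity mode. A rough, self-contained illustration of that inference failure (hypothetical `f_branchy`/`f_checked`, not LuxLib code; exact inference results can vary by Zygote version):

using Test, Zygote

# Value-dependent branch: one path reads `m`, the other ignores it, so the
# gradient w.r.t. `m` is a `Vector` on one path and `nothing` on the other.
f_branchy(x, m) = size(m) == size(x) ? sum(x .* m) : sum(abs2, x)

# Single code path: the size check is a plain guard, the math is unconditional.
f_checked(x, m) = (@assert size(m) == size(x); sum(x .* m))

x, m = randn(3), rand(3)

# The inferred gradient type of `f_branchy` is a `Union` over the two branches,
# so `@inferred` rejects it; `f_checked` has one concretely-typed pullback.
@test_throws ErrorException @inferred(Zygote.gradient(f_branchy, x, m))
@inferred Zygote.gradient(f_checked, x, m)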
src/impl/dropout.jl (19 changes: 9 additions & 10 deletions)
@@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T,
 end
 
 function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
-        p::T, ::True, ::False, invp::T, dims) where {T}
-    if dropout_shape(x, dims) != size(mask)
-        depwarn(
-            "`update_mask` is `Val(false)` but `mask` is not of the same size \
-            as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \
-            will be removed in the next release. Set `update_mask` to \
-            `Val(true)` to avoid this.", :dropout)
-        mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims)
-        return dropout_dot_mul(x, mask), mask, rngₙ
-    end
+        ::T, ::True, ::False, invp::T, dims) where {T}
+    check_dropout_mask_shape_mismatch(x, mask, dims)
     return dropout_dot_mul(x, mask), mask, rng
 end

@@ -31,6 +23,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
     return (x, mask, rng)
 end
 
+function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims)
+    @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`."
+    return nothing
+end
+
+CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...)
+
 ## alpha_dropout
 function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T}
     α = T(-1.7580993408473766)
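`CRC` here is LuxLib's alias for `ChainRulesCore`. The `@non_differentiable` mark tells ChainRules-aware AD backends to run the shape check in the forward pass but drop it from the pullback, so the assertion cannot reintroduce branching into the gradient. A minimal sketch of the same pattern, with a hypothetical `check_positive` guard in place of the real check:

import ChainRulesCore
using Zygote

# A guard in the style of `check_dropout_mask_shape_mismatch`: it only
# validates its input and returns `nothing`, so there is nothing to
# differentiate.
function check_positive(x::AbstractArray)
    @assert all(>(0), x) "`x` must be strictly positive."
    return nothing
end

# Declare the check non-differentiable for every method/argument combination,
# so ChainRules-consuming backends skip it when building the pullback.
ChainRulesCore.@non_differentiable check_positive(::Any...)

# The check still runs in the primal pass but never appears in the gradient.
Zygote.gradient(x -> (check_positive(x); sum(abs2, x)), [1.0, 2.0])  # ([2.0, 4.0],)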
test/common_ops/dropout_tests.jl (34 changes: 1 addition & 33 deletions)
@@ -42,8 +42,6 @@
 end
 
 @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin
-    Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation
-
     using Statistics
 
     rng = StableRNG(12345)
@@ -100,8 +98,7 @@ end
 
             __f = (x, mask) -> sum(first(dropout(
                 StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-            # Branching based on runtime values
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
+            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
 
             __f = let rng = rng, mask = mask
                 x -> sum(first(dropout(
@@ -115,35 +112,6 @@
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType
 
-            # Try using mask if possible (not possible!!)
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
-
-            y, mask_, rng_ = dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
-
-            @test y isa aType{T, length(x_shape)}
-            @test size(y) == x_shape
-            @test mask_ isa aType{T, length(x_shape)}
-            @test size(mask_) == x_shape
-            @test rng != rng_
-            @test mask != mask_
-
-            __f = (x, mask) -> sum(first(dropout(
-                StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-            # Branching based on runtime activity
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
-
-            __f = let rng = rng, mask = mask
-                x -> sum(first(dropout(
-                    rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
-            end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
-
-            @jet sum(first(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             # Testing Mode
             @test @inferred(dropout(
                 rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
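For callers, the behavior change is confined to the preset-mask path with `update_mask = Val(false)`: a correctly-sized mask is reused exactly as before (the same `mask` and `rng` objects are returned), while a mismatched mask now raises an `AssertionError` instead of a deprecation warning plus silent regeneration. A hedged usage sketch, with the argument order `(rng, x, mask, p, training, update_mask, invp, dims)` read off the tests above:

using LuxLib, StableRNGs

rng = StableRNG(12345)
x = randn(Float32, 4, 3)

# Matching mask: reused as-is, RNG state returned unchanged.
mask = rand(Float32, 4, 3)
y, mask_out, rng_out = dropout(
    rng, x, mask, 0.5f0, Val(true), Val(false), 2.0f0, Colon())
@assert mask_out === mask && rng_out === rng

# Mismatched mask: previously a depwarn + mask regeneration, now an error.
bad_mask = rand(Float32, 4, 5)
# dropout(rng, x, bad_mask, 0.5f0, Val(true), Val(false), 2.0f0, Colon())
# ERROR: AssertionError: `mask` is not of the same size as
# `LuxLib.dropout_shape(x, dims)`.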
