Skip to content

Commit

Permalink
feat: easy mechanism to set preferences
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 28, 2024
1 parent 6fa10f8 commit 8c630e4
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 4 deletions.
2 changes: 0 additions & 2 deletions LocalPreferences.toml

This file was deleted.

4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down Expand Up @@ -75,6 +76,7 @@ Compat = "4.15"
ComponentArrays = "0.15.11"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
DispatchDoctor = "0.4.12"
Documenter = "1.4"
DynamicExpressions = "0.16, 0.17, 0.18"
Enzyme = "0.12.24"
Expand All @@ -95,7 +97,7 @@ LossFunctions = "0.11.1"
LuxCore = "0.1.16"
LuxDeviceUtils = "0.1.26"
LuxLib = "0.3.33"
LuxTestUtils = "0.1.15"
LuxTestUtils = "0.1.18"
MLUtils = "0.4.3"
MPI = "0.20.19"
MacroTools = "0.5.13"
Expand Down
6 changes: 6 additions & 0 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ StatefulLuxLayer
@non_trainable
```

## Preferences

```@docs
Lux.set_dispatch_doctor_preferences!
```

## Truncated Stacktraces (Deprecated)

```@docs
Expand Down
7 changes: 7 additions & 0 deletions docs/src/manual/performance_pitfalls.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,10 @@ using:
using GPUArraysCore
GPUArraysCore.allowscalar(false)
```

## Type Instabilities

`Lux.jl` is integrated with `DispatchDoctor.jl` to catch type instabilities. You can easily
enable it by setting the `instability_check` preference. This will help you catch type
instabilities in your code. For more information on how to set preferences, check out
[`set_dispatch_doctor_preferences`](@ref).
11 changes: 11 additions & 0 deletions docs/src/manual/preferences.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,14 @@ By default, both of these preferences are set to `false`.
1. `eltype_mismatch_handling` - Preference controlling what happens when layers get
different eltypes as input. See the documentation on [`match_eltype`](@ref) for more
details.

## [Dispatch Doctor](@id dispatch-doctor-preference)

1. `instability_check` - Preference controlling the dispatch doctor. See the documentation
on [`set_dispatch_doctor_preferences!`](@ref) for more details. The preferences need to
be set for `LuxCore` and `LuxLib` packages. Both of them default to `disable`.
- Setting the `LuxCore` preference sets the check at the level of `LuxCore.apply`. This
essentially activates the dispatch doctor for all Lux layers.
- Setting the `LuxLib` preference sets the check at the level of functional layer of
Lux, for example, [`fused_dense_bias_activation`](@ref). These functions are supposed
to be type stable for common input types and can be used to guarantee type stability.
3 changes: 2 additions & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using MacroTools: MacroTools, block, combinedef, splitdef
using Markdown: @doc_str
using NNlib: NNlib
using Optimisers: Optimisers
using Preferences: load_preference, has_preference
using Preferences: load_preference, has_preference, set_preferences!
using Random: Random, AbstractRNG
using Reexport: @reexport
using Statistics: mean
Expand Down Expand Up @@ -133,6 +133,7 @@ export MPIBackend, NCCLBackend, DistributedUtils
# Unexported functions that are part of the public API
@compat public Experimental
@compat public xlogx, xlogy
@compat public set_dispatch_doctor_preferences!
@compat(public,
(recursive_add!!, recursive_copyto!, recursive_eltype,
recursive_make_zero, recursive_map, recursive_make_zero!!))
Expand Down
39 changes: 39 additions & 0 deletions src/preferences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,42 @@ const MPI_ROCM_AWARE = @deprecate_preference("LuxDistributedMPIROCMAware", "rocm
# Eltype Auto Conversion
const ELTYPE_MISMATCH_HANDLING = @load_preference_with_choices("eltype_mismatch_handling",
"none", ("none", "warn", "convert", "error"))

# Dispatch Doctor
"""
set_dispatch_doctor_preferences!(mode::String)
set_dispatch_doctor_preferences!(; luxcore::String="disable", luxlib::String="disable")
Set the dispatch doctor preference for `LuxCore` and `LuxLib` packages.
`mode` can be `"disable"`, `"warn"`, or `"error"`. For details on the different modes, see
the [DispatchDoctor.jl](https://astroautomata.com/DispatchDoctor.jl/dev/) documentation.
If the preferences are already set, then no action is taken. Otherwise the preference is
set. For changes to take effect, the Julia session must be restarted.
"""
function set_dispatch_doctor_preferences!(mode::String)
return set_dispatch_doctor_preferences!(; luxcore=mode, luxlib=mode)
end

function set_dispatch_doctor_preferences!(;
luxcore::String="disable", luxlib::String="disable")
_set_dispatch_doctor_preferences!(LuxCore, luxcore)
_set_dispatch_doctor_preferences!(LuxLib, luxlib)
return
end

function _set_dispatch_doctor_preferences!(package, mode::String)
@argcheck mode in ("disable", "warn", "error")
if has_preference(package, "dispatch_doctor")
orig_pref = load_preference(package, "dispatch_doctor")
if orig_pref == mode
@info "Dispatch Doctor preference for $(package) is already set to $mode."
return
end
end
set_preferences!(package, "instability_check" => mode; force=true)
@info "Dispatch Doctor preference for $(package) set to $mode. Please restart Julia \
for this change to take effect."
return
end
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,26 @@ if ("all" in LUX_TEST_GROUP || "eltype_match" in LUX_TEST_GROUP)
Test.@test true
end
end

# Set preferences tests
if ("all" in LUX_TEST_GROUP || "others" in LUX_TEST_GROUP)
@testset "DispatchDoctor Preferences" begin
@testset "set_dispatch_doctor_preferences!" begin
@test_throws ArgumentError Lux.set_dispatch_doctor_preferences!("invalid")
@test_throws ArgumentError Lux.set_dispatch_doctor_preferences!(;
luxcore="invalid")

Lux.set_dispatch_doctor_preferences!("disable")
@test Preferences.load_preference(LuxCore, "instability_check") == "disable"
@test Preferences.load_preference(LuxLib, "instability_check") == "disable"

Lux.set_dispatch_doctor_preferences!(; luxcore="warn", luxlib="error")
@test Preferences.load_preference(LuxCore, "instability_check") == "warn"
@test Preferences.load_preference(LuxLib, "instability_check") == "error"

Lux.set_dispatch_doctor_preferences!(; luxcore="error")
@test Preferences.load_preference(LuxCore, "instability_check") == "error"
@test Preferences.load_preference(LuxLib, "instability_check") == "disable"
end
end
end
2 changes: 2 additions & 0 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Lux, Functors
Zygote, Statistics
using LuxTestUtils: @jet, @test_gradients, check_approx

LuxTestUtils.jet_target_modules!(["Lux", "LuxCore", "LuxLib"])

# Some Helper Functions
function get_default_rng(mode::String)
dev = mode == "cpu" ? LuxCPUDevice() :
Expand Down

0 comments on commit 8c630e4

Please sign in to comment.