feat: easy mechanism to set preferences #798

Merged · 2 commits · Jul 28, 2024
2 changes: 0 additions & 2 deletions LocalPreferences.toml

This file was deleted.

4 changes: 3 additions & 1 deletion Project.toml
@@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -75,6 +76,7 @@ Compat = "4.15"
ComponentArrays = "0.15.11"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
DispatchDoctor = "0.4.12"
Documenter = "1.4"
DynamicExpressions = "0.16, 0.17, 0.18"
Enzyme = "0.12.24"
@@ -95,7 +97,7 @@ LossFunctions = "0.11.1"
LuxCore = "0.1.16"
LuxDeviceUtils = "0.1.26"
LuxLib = "0.3.33"
LuxTestUtils = "0.1.15"
LuxTestUtils = "0.1.18"
MLUtils = "0.4.3"
MPI = "0.20.19"
MacroTools = "0.5.13"
1 change: 0 additions & 1 deletion docs/make.jl
@@ -90,7 +90,6 @@ makedocs(; sitename="Lux.jl Docs",
repo="github.com/LuxDL/Lux.jl", devbranch="main", devurl="dev",
deploy_url="https://lux.csail.mit.edu", deploy_decision),
draft=false,
warnonly=:linkcheck, # Lately it has been failing quite a lot but those links are actually fine
pages)

deploydocs(; repo="github.com/LuxDL/Lux.jl.git",
6 changes: 6 additions & 0 deletions docs/src/api/Lux/utilities.md
@@ -114,6 +114,12 @@ StatefulLuxLayer
@non_trainable
```

## Preferences

```@docs
Lux.set_dispatch_doctor_preferences!
```

## Truncated Stacktraces (Deprecated)

```@docs
2 changes: 1 addition & 1 deletion docs/src/manual/autodiff.md
@@ -14,7 +14,7 @@ Lux. Additionally, we provide some convenience functions for working with AD.
| [`ForwardDiff.jl`](https://github.com/JuliaDiff/ForwardDiff.jl) | Forward | ✔️ | ✔️ | ✔️ | Tier I |
| [`ReverseDiff.jl`](https://github.com/JuliaDiff/ReverseDiff.jl) | Reverse | ✔️ ||| Tier II |
| [`Tracker.jl`](https://github.com/FluxML/Tracker.jl) | Reverse | ✔️ | ✔️ || Tier II |
| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) | Reverse |[^q] ||| Tier III |
| [`Tapir.jl`](https://github.com/compintell/Tapir.jl) | Reverse |[^q] ||| Tier III |
| [`Diffractor.jl`](https://github.com/JuliaDiff/Diffractor.jl) | Forward |[^q] |[^q] |[^q] | Tier III |

[^e]: Currently Enzyme outperforms other AD packages in terms of CPU performance. However,
6 changes: 3 additions & 3 deletions docs/src/manual/nested_autodiff.md
@@ -192,9 +192,9 @@ nothing; # hide

Hutchinson Trace Estimation often shows up in machine learning literature to provide a fast
estimate of the trace of a Jacobian Matrix. This is based off of
[Hutchinson 1990](https://www.researchgate.net/publication/243668757_A_Stochastic_Estimator_of_the_Trace_of_the_Influence_Matrix_for_Laplacian_Smoothing_Splines) which
computes the estimated trace of a matrix ``A \in \mathbb{R}^{D \times D}`` using random
vectors ``v \in \mathbb{R}^{D}`` s.t. ``\mathbb{E}\left[v v^T\right] = I``.
[Hutchinson 1990](https://www.nowozin.net/sebastian/blog/thoughts-on-trace-estimation-in-deep-learning.html)
which computes the estimated trace of a matrix ``A \in \mathbb{R}^{D \times D}`` using
random vectors ``v \in \mathbb{R}^{D}`` s.t. ``\mathbb{E}\left[v v^T\right] = I``.

```math
\text{Tr}(A) = \mathbb{E}\left[v^T A v\right] = \frac{1}{V} \sum_{i = 1}^V v_i^T A v_i
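
For reference, a minimal plain-Julia sketch of this estimator (illustrative only, independent of Lux; the helper name `hutchinson_trace` is made up here):

```julia
using LinearAlgebra

# Estimate tr(A) with V Rademacher probe vectors v (entries ±1), so that E[v vᵀ] = I.
function hutchinson_trace(A::AbstractMatrix, V::Int)
    D = size(A, 1)
    est = 0.0
    for _ in 1:V
        v = rand((-1.0, 1.0), D)   # Rademacher probe vector
        est += dot(v, A * v)       # one sample of vᵀ A v
    end
    return est / V
end

A = randn(8, 8)
(hutchinson_trace(A, 50_000), tr(A))  # the two values agree up to Monte-Carlo error
```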
7 changes: 7 additions & 0 deletions docs/src/manual/performance_pitfalls.md
@@ -61,3 +61,10 @@ using:
using GPUArraysCore
GPUArraysCore.allowscalar(false)
```

## Type Instabilities

`Lux.jl` is integrated with `DispatchDoctor.jl` to catch type instabilities. You can easily
enable it by setting the `instability_check` preference. This will help you catch type
instabilities in your code. For more information on how to set preferences, check out
[`Lux.set_dispatch_doctor_preferences!`](@ref).
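
As a sketch of how that looks in practice, using the helper introduced in this PR (a Julia restart is required for the preference to take effect):

```julia
using Lux

# Enable the dispatch doctor in "error" mode for both LuxCore and LuxLib,
# then restart Julia so the new preference is picked up.
Lux.set_dispatch_doctor_preferences!("error")
```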
11 changes: 11 additions & 0 deletions docs/src/manual/preferences.md
@@ -46,3 +46,14 @@ By default, both of these preferences are set to `false`.
1. `eltype_mismatch_handling` - Preference controlling what happens when layers get
different eltypes as input. See the documentation on [`match_eltype`](@ref) for more
details.

## [Dispatch Doctor](@id dispatch-doctor-preference)

1. `instability_check` - Preference controlling the dispatch doctor. See the documentation
on [`Lux.set_dispatch_doctor_preferences!`](@ref) for more details. The preferences need
to be set for `LuxCore` and `LuxLib` packages. Both of them default to `disable`.
- Setting the `LuxCore` preference sets the check at the level of `LuxCore.apply`. This
essentially activates the dispatch doctor for all Lux layers.
- Setting the `LuxLib` preference sets the check at the level of functional layer of
Lux, for example, [`fused_dense_bias_activation`](@ref). These functions are supposed
to be type stable for common input types and can be used to guarantee type stability.
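
For illustration, a sketch of setting the two preferences independently via the helper added in this PR (restart Julia afterwards):

```julia
using Lux

# Check at the `LuxCore.apply` level (i.e. for all Lux layers) while leaving the
# LuxLib functional layer unchecked.
Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="disable")
```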
2 changes: 1 addition & 1 deletion examples/Basics/main.jl
@@ -3,7 +3,7 @@
# This is a quick intro to [Lux](https://github.com/LuxDL/Lux.jl) loosely based on:
#
# 1. [PyTorch's tutorial](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
# 2. [Flux's tutorial](https://fluxml.ai/Flux.jl/stable/tutorials/2020-09-15-deep-learning-flux/).
# 2. Flux's tutorial (the link for which has now been lost to the abyss).
# 3. [Jax's tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html).
#
# It introduces basic Julia programming, as well `Zygote`, a source-to-source automatic
10 changes: 6 additions & 4 deletions examples/BayesianNN/main.jl
@@ -1,11 +1,13 @@
# # Bayesian Neural Network

# We borrow this tutorial from the
# [official Turing Docs](https://turinglang.org/stable/tutorials/03-bayesian-neural-network/). We
# will show how the explicit parameterization of Lux enables first-class composability with
# packages which expect flattened out parameter vectors.
# [official Turing Docs](https://turinglang.org/docs/tutorials/03-bayesian-neural-network/index.html).
# We will show how the explicit parameterization of Lux enables first-class composability
# with packages which expect flattened out parameter vectors.

# We will use [Turing.jl](https://turinglang.org/stable/) with [Lux.jl](https://lux.csail.mit.edu/)
# Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

# We will use [Turing.jl](https://turinglang.org/) with [Lux.jl](https://lux.csail.mit.edu/)
# to implement a classification algorithm. Let's start by importing the relevant
# libraries.

4 changes: 2 additions & 2 deletions examples/SymbolicOptimalControl/main.jl
@@ -2,8 +2,8 @@

# This tutorial is based on [SciMLSensitivity.jl tutorial](https://docs.sciml.ai/SciMLSensitivity/stable/examples/optimal_control/optimal_control/).
# Instead of using a classical NN architecture, here we will combine the NN with a symbolic
# expression from [DynamicExpressions.jl](https://symbolicml.org/DynamicExpressions.jl) (the
# symbolic engine behind [SymbolicRegression.jl](https://astroautomata.com/SymbolicRegression.jl)
# expression from [DynamicExpressions.jl](https://symbolicml.org/DynamicExpressions.jl/) (the
# symbolic engine behind [SymbolicRegression.jl](https://astroautomata.com/SymbolicRegression.jl/)
# and [PySR](https://github.com/MilesCranmer/PySR/)).

# Here we will solve a classic optimal control problem with a universal differential
3 changes: 2 additions & 1 deletion src/Lux.jl
@@ -20,7 +20,7 @@ using MacroTools: MacroTools, block, combinedef, splitdef
using Markdown: @doc_str
using NNlib: NNlib
using Optimisers: Optimisers
using Preferences: load_preference, has_preference
using Preferences: load_preference, has_preference, set_preferences!
using Random: Random, AbstractRNG
using Reexport: @reexport
using Statistics: mean
@@ -133,6 +133,7 @@ export MPIBackend, NCCLBackend, DistributedUtils
# Unexported functions that are part of the public API
@compat public Experimental
@compat public xlogx, xlogy
@compat public set_dispatch_doctor_preferences!
@compat(public,
(recursive_add!!, recursive_copyto!, recursive_eltype,
recursive_make_zero, recursive_map, recursive_make_zero!!))
2 changes: 1 addition & 1 deletion src/helpers/losses.jl
@@ -595,7 +595,7 @@ true
## Special Note
This function takes any of the
[`LossFunctions.jl`](https://juliaml.github.io/LossFunctions.jl/stable) public functions
[`LossFunctions.jl`](https://juliaml.github.io/LossFunctions.jl/stable/) public functions
into the Lux Losses API with efficient aggregation.
"""
@concrete struct GenericLossFunction <: AbstractLossFunction
39 changes: 39 additions & 0 deletions src/preferences.jl
@@ -38,3 +38,42 @@ const MPI_ROCM_AWARE = @deprecate_preference("LuxDistributedMPIROCMAware", "rocm
# Eltype Auto Conversion
const ELTYPE_MISMATCH_HANDLING = @load_preference_with_choices("eltype_mismatch_handling",
"none", ("none", "warn", "convert", "error"))

# Dispatch Doctor
"""
set_dispatch_doctor_preferences!(mode::String)
set_dispatch_doctor_preferences!(; luxcore::String="disable", luxlib::String="disable")

Set the dispatch doctor preference for `LuxCore` and `LuxLib` packages.

`mode` can be `"disable"`, `"warn"`, or `"error"`. For details on the different modes, see
the [DispatchDoctor.jl](https://astroautomata.com/DispatchDoctor.jl/dev/) documentation.

If the preferences are already set, then no action is taken. Otherwise the preference is
set. For changes to take effect, the Julia session must be restarted.
"""
function set_dispatch_doctor_preferences!(mode::String)
return set_dispatch_doctor_preferences!(; luxcore=mode, luxlib=mode)
end

function set_dispatch_doctor_preferences!(;
luxcore::String="disable", luxlib::String="disable")
_set_dispatch_doctor_preferences!(LuxCore, luxcore)
_set_dispatch_doctor_preferences!(LuxLib, luxlib)
return
end

function _set_dispatch_doctor_preferences!(package, mode::String)
@argcheck mode in ("disable", "warn", "error")
if has_preference(package, "instability_check")
orig_pref = load_preference(package, "instability_check")
if orig_pref == mode
@info "Dispatch Doctor preference for $(package) is already set to $mode."
return
end
end
set_preferences!(package, "instability_check" => mode; force=true)
@info "Dispatch Doctor preference for $(package) set to $mode. Please restart Julia \
for this change to take effect."
return
end
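
For context, a rough sketch of the kind of check DispatchDoctor.jl performs once enabled, assuming its documented `@stable` macro (this is not Lux's actual integration; `unstable_f` is a made-up example):

```julia
using DispatchDoctor: @stable

# A deliberately type-unstable function: the inferred return type is Union{Int, Float64}.
@stable function unstable_f(x)
    return x > 0 ? x : 0
end

unstable_f(1.0)  # in "error" mode this throws a TypeInstabilityError;
                 # in "warn" mode it only logs a warning
```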
23 changes: 23 additions & 0 deletions test/runtests.jl
@@ -131,3 +131,26 @@ if ("all" in LUX_TEST_GROUP || "eltype_match" in LUX_TEST_GROUP)
Test.@test true
end
end

# Set preferences tests
if ("all" in LUX_TEST_GROUP || "others" in LUX_TEST_GROUP)
@testset "DispatchDoctor Preferences" begin
@testset "set_dispatch_doctor_preferences!" begin
@test_throws ArgumentError Lux.set_dispatch_doctor_preferences!("invalid")
@test_throws ArgumentError Lux.set_dispatch_doctor_preferences!(;
luxcore="invalid")

Lux.set_dispatch_doctor_preferences!("disable")
@test Preferences.load_preference(LuxCore, "instability_check") == "disable"
@test Preferences.load_preference(LuxLib, "instability_check") == "disable"

Lux.set_dispatch_doctor_preferences!(; luxcore="warn", luxlib="error")
@test Preferences.load_preference(LuxCore, "instability_check") == "warn"
@test Preferences.load_preference(LuxLib, "instability_check") == "error"

Lux.set_dispatch_doctor_preferences!(; luxcore="error")
@test Preferences.load_preference(LuxCore, "instability_check") == "error"
@test Preferences.load_preference(LuxLib, "instability_check") == "disable"
end
end
end
2 changes: 2 additions & 0 deletions test/shared_testsetup.jl
@@ -9,6 +9,8 @@ using Lux, Functors
Zygote, Statistics
using LuxTestUtils: @jet, @test_gradients, check_approx

LuxTestUtils.jet_target_modules!(["Lux", "LuxCore", "LuxLib"])

# Some Helper Functions
function get_default_rng(mode::String)
dev = mode == "cpu" ? LuxCPUDevice() :