From d5e8545b97ca76e3ee75509dfe7087e05d66a11f Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 4 Apr 2024 19:53:08 -0500 Subject: [PATCH] Add Preferences --- Project.toml | 4 ++++ ext/JACCAMDGPU/JACCAMDGPU.jl | 5 ++++- ext/JACCCUDA/JACCCUDA.jl | 5 ++++- ext/JACCONEAPI/JACCONEAPI.jl | 5 ++++- src/JACC.jl | 6 +++++- src/JACCPreferences.jl | 22 ++++++++++++++++++++++ test/tests_threads_perf.jl | 2 +- 7 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 src/JACCPreferences.jl diff --git a/Project.toml b/Project.toml index 0488902..810a11f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,9 @@ uuid = "0979c8fe-16a4-4796-9b82-89a9f10403ea" authors = ["pedrovalerolara ", "williamfgc "] version = "0.0.1" +[deps] +Preferences = "21216c6a-2e73-6563-6e65-726566657250" + [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -14,6 +17,7 @@ JACCCUDA = ["CUDA"] JACCONEAPI = ["oneAPI"] [compat] +Preferences = "1.4.0" julia = "1.9.0" [extras] diff --git a/ext/JACCAMDGPU/JACCAMDGPU.jl b/ext/JACCAMDGPU/JACCAMDGPU.jl index 35db2a3..6418af5 100644 --- a/ext/JACCAMDGPU/JACCAMDGPU.jl +++ b/ext/JACCAMDGPU/JACCAMDGPU.jl @@ -300,7 +300,10 @@ function reduce_kernel_amdgpu_MN((M, N), red, ret) end function __init__() - const JACC.default_backend = ROCBackend() + if JACC.JACCPreferences.backend == "amdgpu" + const JACC.default_backend = ROCBackend() + @info "Set default backend to $(JACC.default_backend)" + end end end # module JACCAMDGPU diff --git a/ext/JACCCUDA/JACCCUDA.jl b/ext/JACCCUDA/JACCCUDA.jl index 2a0d13b..003ff9e 100644 --- a/ext/JACCCUDA/JACCCUDA.jl +++ b/ext/JACCCUDA/JACCCUDA.jl @@ -294,7 +294,10 @@ function reduce_kernel_cuda_MN((M, N), red, ret) end function __init__() - const JACC.default_backend = CUDABackend() + if JACC.JACCPreferences.backend == "cuda" + const JACC.default_backend = CUDABackend() + @info "Set default backend to $(JACC.default_backend)" + end end end # module JACCCUDA diff --git a/ext/JACCONEAPI/JACCONEAPI.jl b/ext/JACCONEAPI/JACCONEAPI.jl index 6122a27..215a2a2 100644 --- a/ext/JACCONEAPI/JACCONEAPI.jl +++ b/ext/JACCONEAPI/JACCONEAPI.jl @@ -294,7 +294,10 @@ function reduce_kernel_oneapi_MN((M, N), red, ret) end function __init__() - const JACC.default_backend = oneAPIBackend() + if JACC.JACCPreferences.backend == "oneapi" + const JACC.default_backend = oneAPIBackend() + @info "Set default backend to $(JACC.default_backend)" + end end end # module JACCONEAPI diff --git a/src/JACC.jl b/src/JACC.jl index b73ee9b..2864a1e 100644 --- a/src/JACC.jl +++ b/src/JACC.jl @@ -1,6 +1,7 @@ module JACC # module to set back end preferences +include("JACCPreferences.jl") include("helper.jl") export parallel_for, parallel_reduce, ThreadsBackend, print_default_backend @@ -69,7 +70,10 @@ function parallel_reduce(::ThreadsBackend, (M, N)::Tuple{I,I}, f::F, x...) where end function __init__() - const JACC.default_backend = ThreadsBackend() + if JACCPreferences.backend == "threads" + const JACC.default_backend = ThreadsBackend() + @info "Set default backend to $(JACC.default_backend)" + end end function print_default_backend() diff --git a/src/JACCPreferences.jl b/src/JACCPreferences.jl new file mode 100644 index 0000000..13643fd --- /dev/null +++ b/src/JACCPreferences.jl @@ -0,0 +1,22 @@ + +module JACCPreferences + +using Preferences + +# taken from https://github.com/JuliaPackaging/Preferences.jl +function set_backend(new_backend::String) + + new_backend_lc = lowercase(new_backend) + if !(new_backend_lc in ("threads", "cuda", "amdgpu", "oneapi")) + throw(ArgumentError("Invalid backend: \"$(new_backend)\"")) + end + + # Set it in our runtime values, as well as saving it to disk + @set_preferences!("backend" => new_backend_lc) + @info("New backend set; restart your Julia session for this change to take effect!") +end + +const backend = @load_preference("backend", "threads") + + +end # module JACCPreferences diff --git a/test/tests_threads_perf.jl b/test/tests_threads_perf.jl index 508a6b2..41279a1 100644 --- a/test/tests_threads_perf.jl +++ b/test/tests_threads_perf.jl @@ -35,4 +35,4 @@ using Test for i in 1:ntimes @time JACC.parallel_for(ThreadsBackend(), N, axpy, alpha, x_JACC, y_JACC) end -end \ No newline at end of file +end