Skip to content

Commit

Permalink
Add Preferences
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Apr 5, 2024
1 parent 31207e1 commit f5111df
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 4 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ uuid = "0979c8fe-16a4-4796-9b82-89a9f10403ea"
authors = ["pedrovalerolara <valerolarap@ornl.gov>", "williamfgc <williamfgc@yahoo.com>"]
version = "0.0.1"

[deps]
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
5 changes: 4 additions & 1 deletion ext/JACCAMDGPU/JACCAMDGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,10 @@ function reduce_kernel_amdgpu_MN((M, N), red, ret)
end

function __init__()
const JACC.default_backend = ROCBackend()
if JACCPreferences.backend == "amdgpu"
const JACC.default_backend = ROCBackend()
@info "Set default backend to $(JACC.default_backend)"
end
end

end # module JACCAMDGPU
5 changes: 4 additions & 1 deletion ext/JACCCUDA/JACCCUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ function reduce_kernel_cuda_MN((M, N), red, ret)
end

function __init__()
const JACC.default_backend = CUDABackend()
if JACC.JACCPreferences.backend == "cuda"
const JACC.default_backend = CUDABackend()
@info "Set default backend to $(JACC.default_backend)"
end
end

end # module JACCCUDA
5 changes: 4 additions & 1 deletion ext/JACCONEAPI/JACCONEAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ function reduce_kernel_oneapi_MN((M, N), red, ret)
end

function __init__()
const JACC.default_backend = oneAPIBackend()
if JACC.JACCPreferences.backend == "oneapi"
const JACC.default_backend = oneAPIBackend()
@info "Set default backend to $(JACC.default_backend)"
end
end

end # module JACCONEAPI
6 changes: 5 additions & 1 deletion src/JACC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module JACC

# module to set back end preferences
include("helper.jl")
include("JACCPreferences.jl")

export parallel_for, parallel_reduce, ThreadsBackend, print_default_backend

Expand Down Expand Up @@ -69,7 +70,10 @@ function parallel_reduce(::ThreadsBackend, (M, N)::Tuple{I,I}, f::F, x...) where
end

function __init__()
const JACC.default_backend = ThreadsBackend()
if JACCPreferences.backend == "threads"
const JACC.default_backend = ThreadsBackend()
@info "Set default backend to $(JACC.default_backend)"
end
end

function print_default_backend()
Expand Down
22 changes: 22 additions & 0 deletions src/JACCPreferences.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

module JACCPreferences

using Preferences

# taken from https://github.com/JuliaPackaging/Preferences.jl
function set_backend(new_backend::String)

new_backend_lc = lowercase(new_backend)
if !(new_backend_lc in ("threads", "cuda", "amdgpu", "oneapi"))
throw(ArgumentError("Invalid backend: \"$(new_backend)\""))
end

# Set it in our runtime values, as well as saving it to disk
@set_preferences!("backend" => new_backend_lc)
@info("New backend set; restart your Julia session for this change to take effect!")
end

const backend = @load_preference("backend", "threads")


end # module JACCPreferences

0 comments on commit f5111df

Please sign in to comment.