
Can RNG be moved to GPUArraysCore? #556

Open
avik-pal opened this issue Aug 28, 2024 · 2 comments

Comments

@avik-pal

From the definition of the struct, it seems like this should be as simple as copy-pasting it over. I can open a PR for this, but I wanted to know if there was any reason this wasn't done previously.

Metal and oneAPI default_rng return RNG objects, so it is very useful (in the deep learning context) to define dispatches on RNG to initialize data directly on the device (for weight initialization, functions that need random numbers at runtime, etc.). To do this, I end up defining weakdeps like:

  1. https://github.com/LuxDL/WeightInitializers.jl/blob/main/ext/WeightInitializersGPUArraysExt.jl
  2. https://github.com/LuxDL/MLDataDevices.jl/blob/main/ext/MLDataDevicesGPUArraysExt.jl

Moving it to GPUArraysCore would make this significantly easier, since that package is lightweight enough that no extensions are needed.
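
For illustration, here is a minimal sketch of the kind of dispatch those extensions define, assuming GPUArrays.jl's RNG (whose struct carries its device-side state in a `state` field); `init_uniform` is a hypothetical name, not the actual WeightInitializers API:

```julia
using GPUArrays, Random

# Hypothetical initializer: dispatching on GPUArrays.RNG lets us allocate
# and fill the output directly on the RNG's device, never touching the host.
function init_uniform(rng::GPUArrays.RNG, ::Type{T}, dims::Integer...) where {T<:Number}
    y = similar(rng.state, T, dims...)  # rng.state lives on the target device
    rand!(rng, y)                       # GPUArrays implements rand! for its own RNG
    return y
end
```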

@maleadt
Member

maleadt commented Aug 29, 2024

> Metal and oneAPI default_rng return RNG objects

Part of the idea here was that users shouldn't have to care about which RNG object they get. In some cases it may be a native one (Metal.jl nowadays has its own RNG), in other cases the GPUArrays one; it depends on the exact invocation of the relevant Random function (be it in-place or out-of-place). So I'm wondering if this is even the right approach, as you're now always using the slow GPUArrays RNG while the package may support a much faster one. Ideally you wouldn't hard-code the RNG type, but just use whatever the package returns for this operation, if that's an option.
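
To make that concrete, a hedged sketch (assuming CUDA.jl as the backend; the same pattern applies to Metal.jl) of how the servicing RNG can differ by invocation:

```julia
using CUDA, GPUArrays, Random

a = CUDA.zeros(Float32, 16)
rand!(a)  # in-place: the package picks the RNG; for supported eltypes this
          # is typically the native (CURAND-backed) generator

CUDA.default_rng()               # the package's preferred native RNG
GPUArrays.default_rng(CuArray)   # the generic, slower GPUArrays fallback
```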

@avik-pal
Author

> Metal.jl nowadays has its own RNG

Oh I wasn't aware of that. I will add a dispatch for that.

> So I'm wondering if this is even the right approach, as you're now always using the slow GPUArrays RNG while the package may support a much faster one.

The examples I linked above are mostly to support "all possible" types of RNGs. For example, if the user passes in CUDA.RNG or the rocRAND-based one, we use the faster dispatches. By default, we select the RNG returned by default_rng (if the package provides one) and otherwise fall back to gpuarrays_rng.

The reason to allow GPUArrays.RNG is also to cover backends that don't provide a native RNG object (I think oneAPI is one of them currently?). In that case, not supporting GPUArrays.RNG means calls like copyto!(::GPUArray, rand!(<Random RNG object>, ::Array)), which tend to be slower.
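
For reference, a rough sketch of that slow path (`host_roundtrip_fill!` is just an illustrative name): generate on the host, then pay for a host allocation plus a transfer, both of which a device-side RNG avoids.

```julia
using Random

function host_roundtrip_fill!(dev::AbstractArray, rng::Random.AbstractRNG)
    host = rand(rng, eltype(dev), size(dev)...)  # generated on the CPU
    return copyto!(dev, host)                    # then copied to the device
end
```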
