From a9a1c44abaf671f06794f6664fe3ca47bcade5e5 Mon Sep 17 00:00:00 2001 From: EtienneDosSantos <130935112+EtienneDosSantos@users.noreply.github.com> Date: Sun, 26 May 2024 15:43:24 +0200 Subject: [PATCH 1/3] Add `"lamb"` to `str2optimizer32bit` --- bitsandbytes/functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f915223ca..dc1490482 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -49,6 +49,10 @@ def prod(iterable): lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_16, ), + "lamb": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + ), } str2optimizer8bit = { From 2e46eefcb214cffc0fb9d6ace71f53924f9c7873 Mon Sep 17 00:00:00 2001 From: EtienneDosSantos <130935112+EtienneDosSantos@users.noreply.github.com> Date: Tue, 28 May 2024 18:35:31 +0200 Subject: [PATCH 2/3] Sorted alphabetically for better overview --- bitsandbytes/functional.py | 64 +++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index dc1490482..0b1e7d5c4 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -27,11 +27,24 @@ def prod(iterable): if lib and lib.compiled_with_cuda: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { + "adagrad": ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ), "adam": ( lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16, ), + "lamb": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + ), + "lion": ( + lib.clion32bit_grad_fp32, + lib.clion32bit_grad_fp16, + lib.clion32bit_grad_bf16, + ), "momentum": ( lib.cmomentum32bit_grad_32, lib.cmomentum32bit_grad_16, @@ -40,19 +53,6 @@ def prod(iterable): lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_16, ), - "lion": ( - lib.clion32bit_grad_fp32, - lib.clion32bit_grad_fp16, - lib.clion32bit_grad_bf16, - ), - "adagrad": ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, - ), - "lamb": ( - lib.cadam32bit_grad_fp32, - lib.cadam32bit_grad_fp16, - ), } str2optimizer8bit = { @@ -60,34 +60,43 @@ def prod(iterable): lib.cadam_static_8bit_grad_32, lib.cadam_static_8bit_grad_16, ), - "momentum": ( + "lamb": ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ), + "lars": ( lib.cmomentum_static_8bit_grad_32, lib.cmomentum_static_8bit_grad_16, ), - "rmsprop": ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, - ), "lion": ( lib.clion_static_8bit_grad_32, lib.clion_static_8bit_grad_16, ), - "lamb": ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, - ), - "lars": ( + "momentum": ( lib.cmomentum_static_8bit_grad_32, lib.cmomentum_static_8bit_grad_16, ), + "rmsprop": ( + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, + ), } str2optimizer8bit_blockwise = { + "adagrad": ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + ), "adam": ( lib.cadam_8bit_blockwise_grad_fp32, lib.cadam_8bit_blockwise_grad_fp16, lib.cadam_8bit_blockwise_grad_bf16, ), + "lion": ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, + ), "momentum": ( lib.cmomentum_8bit_blockwise_grad_fp32, lib.cmomentum_8bit_blockwise_grad_fp16, @@ -96,15 +105,6 @@ def prod(iterable): lib.crmsprop_8bit_blockwise_grad_fp32, lib.crmsprop_8bit_blockwise_grad_fp16, ), - "lion": ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, - ), - "adagrad": ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - ), } From 7a338db2eccbd60b7da3b7bed9c927117c6b3806 Mon Sep 17 00:00:00 2001 From: EtienneDosSantos <130935112+EtienneDosSantos@users.noreply.github.com> Date: Tue, 28 May 2024 19:53:57 +0200 Subject: [PATCH 3/3] Update functional.py --- bitsandbytes/functional.py | 88 ++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 0b1e7d5c4..bbfbf0007 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -28,57 +28,94 @@ def prod(iterable): """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adagrad": ( - lib.cadagrad32bit_grad_32, - lib.cadagrad32bit_grad_16, + lib.cadagrad32bit_grad_fp32, + lib.cadagrad32bit_grad_fp16, ), "adam": ( lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16, ), + "pagedadam": ( + lib.cpagedadam32bit_grad_fp32, + lib.cpagedadam32bit_grad_fp16, + lib.cpagedadam32bit_grad_bf16, + ), + "adamw": ( + lib.cadam32bit_grad_fp32, + lib.cadam32bit_grad_fp16, + lib.cadam32bit_grad_bf16, + ), + "pagedadamw": ( + lib.cpagedadam32bit_grad_fp32, + lib.cpagedadam32bit_grad_fp16, + lib.cpagedadam32bit_grad_bf16, + ), "lamb": ( lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, ), + "lars": ( + lib.clars32bit_grad_fp32, + lib.clars32bit_grad_fp16, + ), "lion": ( lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16, ), "momentum": ( - lib.cmomentum32bit_grad_32, - lib.cmomentum32bit_grad_16, + lib.cmomentum32bit_grad_fp32, + lib.cmomentum32bit_grad_fp16, ), "rmsprop": ( - lib.crmsprop32bit_grad_32, - lib.crmsprop32bit_grad_16, + lib.crmsprop32bit_grad_fp32, + lib.crmsprop32bit_grad_fp16, ), } str2optimizer8bit = { + "adagrad": ( + lib.cadagrad8bit_grad_fp32, + lib.cadagrad8bit_grad_fp16, + ), "adam": ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, + lib.cadam_static_8bit_grad_fp32, + lib.cadam_static_8bit_grad_fp16, + ), + "pagedadam": ( + lib.cpagedadam8bit_grad_fp32, + lib.cpagedadam8bit_grad_fp16, + lib.cpagedadam8bit_grad_bf16, + ), + "adamw": ( + lib.cadam_static_8bit_grad_fp32, + lib.cadam_static_8bit_grad_fp16, + ), + "pagedadamw": ( + lib.cpagedadam8bit_grad_fp32, + lib.cpagedadam8bit_grad_fp16, + lib.cpagedadam8bit_grad_bf16, ), "lamb": ( - lib.cadam_static_8bit_grad_32, - lib.cadam_static_8bit_grad_16, + lib.cadam_static_8bit_grad_fp32, + lib.cadam_static_8bit_grad_fp16, ), "lars": ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, + lib.clars8bit_grad_fp32, + lib.clars8bit_grad_fp16, ), "lion": ( - lib.clion_static_8bit_grad_32, - lib.clion_static_8bit_grad_16, + lib.clion_static_8bit_grad_fp32, + lib.clion_static_8bit_grad_fp16, ), "momentum": ( - lib.cmomentum_static_8bit_grad_32, - lib.cmomentum_static_8bit_grad_16, + lib.cmomentum_static_8bit_grad_fp32, + lib.cmomentum_static_8bit_grad_fp16, ), "rmsprop": ( - lib.crmsprop_static_8bit_grad_32, - lib.crmsprop_static_8bit_grad_16, + lib.crmsprop_static_8bit_grad_fp32, + lib.crmsprop_static_8bit_grad_fp16, ), } @@ -92,6 +129,21 @@ def prod(iterable): lib.cadam_8bit_blockwise_grad_fp16, lib.cadam_8bit_blockwise_grad_bf16, ), + "pagedadam": ( + lib.cpagedadam8bit_blockwise_fp32, + lib.cpagedadam8bit_blockwise_fp16, + lib.cpagedadam8bit_blockwise_bf16, + ), + "adamw": ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ), + "pagedadamw": ( + lib.cpagedadam8bit_blockwise_fp32, + lib.cpagedadam8bit_blockwise_fp16, + lib.cpagedadam8bit_blockwise_bf16, + ), "lion": ( lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp16,