From 6a4142a584a3c4c299f4f1f57af587d5c28cdbad Mon Sep 17 00:00:00 2001 From: chengchangxu Date: Thu, 5 Dec 2024 17:55:00 +0800 Subject: [PATCH 1/7] support BF16 --- Project.toml | 2 ++ src/DLPack.jl | 3 +++ 2 files changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 5167ab2..b2e444d 100644 --- a/Project.toml +++ b/Project.toml @@ -4,9 +4,11 @@ authors = ["Pablo Zubieta"] version = "0.1.1" [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] +BFloat16s = "0.5.0" CUDA = "≥ 1.3" PyCall = "1.92" PythonCall = "0.6.1" diff --git a/src/DLPack.jl b/src/DLPack.jl index 4178375..6f93bdd 100644 --- a/src/DLPack.jl +++ b/src/DLPack.jl @@ -18,6 +18,7 @@ module DLPack ## Dependencies ## using Requires +using BFloat16s: BFloat16 ## Exports ## @@ -280,6 +281,7 @@ jltypes_to_dtypes() = Dict( UInt16 => DLDataType(kDLUInt, 16, 1), UInt32 => DLDataType(kDLUInt, 32, 1), UInt64 => DLDataType(kDLUInt, 64, 1), + BFloat16 => DLDataType(kDLBfloat, 16, 1), Float16 => DLDataType(kDLFloat, 16, 1), Float32 => DLDataType(kDLFloat, 32, 1), Float64 => DLDataType(kDLFloat, 64, 1), @@ -301,6 +303,7 @@ dtypes_to_jltypes() = Dict( DLDataType(kDLUInt, 16, 1) => UInt16, DLDataType(kDLUInt, 32, 1) => UInt32, DLDataType(kDLUInt, 64, 1) => UInt64, + DLDataType(kDLBfloat, 16, 1) => BFloat16, DLDataType(kDLFloat, 16, 1) => Float16, DLDataType(kDLFloat, 32, 1) => Float32, DLDataType(kDLFloat, 64, 1) => Float64, From 7d88d40c12d18cff0fee362d08c88bc38286d188 Mon Sep 17 00:00:00 2001 From: chengchangxu Date: Thu, 5 Dec 2024 17:56:41 +0800 Subject: [PATCH 2/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1e53346..149ea03 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DLPack" uuid = "53c2dc0f-f7d5-43fd-8906-6c0220547083" authors = ["Pablo Zubieta"] -version = "0.3.0" +version = "0.3.1" [deps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" From 5f2aff5197be4171ea2118c157cabe46c53e43e2 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Sat, 28 Dec 2024 21:56:14 -0600 Subject: [PATCH 3/7] Apply suggestions from code review --- Project.toml | 2 +- src/DLPack.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 149ea03..3a3505f 100644 --- a/Project.toml +++ b/Project.toml @@ -18,8 +18,8 @@ PyCallExt = "PyCall" PythonCallExt = "PythonCall" [compat] -BFloat16s = "0.5.0" Aqua = "0.8" +BFloat16s = "0.4" CUDA = "≥ 1.3" PyCall = "1.92" PythonCall = "≥ 0.7" diff --git a/src/DLPack.jl b/src/DLPack.jl index 9079240..fe741bc 100644 --- a/src/DLPack.jl +++ b/src/DLPack.jl @@ -17,8 +17,8 @@ module DLPack ## Dependencies ## -using Requires using BFloat16s: BFloat16 +using Requires ## Types ## From fb2ce81e7437e10681c2bd3e6488198c9c02f0ae Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Sat, 28 Dec 2024 22:52:08 -0600 Subject: [PATCH 4/7] Set BFloat16s compat as lower bound --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3a3505f..86c9071 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ PythonCallExt = "PythonCall" [compat] Aqua = "0.8" -BFloat16s = "0.4" +BFloat16s = "≥ 0.4" CUDA = "≥ 1.3" PyCall = "1.92" PythonCall = "≥ 0.7" From b9330ea65a0f0a33158b13e96050ffc97d5b97f2 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Sun, 29 Dec 2024 00:13:14 -0600 Subject: [PATCH 5/7] Restrict numpy version until we add version 1 ABI support --- test/setup_python_env.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/setup_python_env.jl b/test/setup_python_env.jl index 73a7a9b..7f0a9ba 100644 --- a/test/setup_python_env.jl +++ b/test/setup_python_env.jl @@ -11,7 +11,7 @@ else "jax", ] end -push!(python_deps, "pytorch", "setuptools<70") +push!(python_deps, "numpy<2.1", "pytorch", "setuptools<70") CondaPkg.add(CondaPkg.PkgREPL.parse_pkg.(python_deps)) From 74bc107e1b2299b84684435137a118e6f24f172a Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Sun, 29 Dec 2024 00:57:50 -0600 Subject: [PATCH 6/7] Add testing for 1.10 (lts) --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91278f2..53e4bbe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,6 +20,7 @@ jobs: matrix: version: - '1.6' + - '1.10' - '1' - 'nightly' os: From 0ebb6846ff3b9e353d5bde3153508a728808bd21 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 30 Dec 2024 12:52:29 +0800 Subject: [PATCH 7/7] add restriction to jax version --- test/setup_python_env.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/setup_python_env.jl b/test/setup_python_env.jl index 7f0a9ba..53cbc17 100644 --- a/test/setup_python_env.jl +++ b/test/setup_python_env.jl @@ -8,7 +8,7 @@ python_deps = if VERSION == v"1.6.7" ] else [ - "jax", + "jax<0.4", ] end push!(python_deps, "numpy<2.1", "pytorch", "setuptools<70")