diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91278f2..53e4bbe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,6 +20,7 @@ jobs: matrix: version: - '1.6' + - '1.10' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index ce8f6fe..86c9071 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "DLPack" uuid = "53c2dc0f-f7d5-43fd-8906-6c0220547083" authors = ["Pablo Zubieta"] -version = "0.3.0" +version = "0.3.1" [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] @@ -18,6 +19,7 @@ PythonCallExt = "PythonCall" [compat] Aqua = "0.8" +BFloat16s = "≥ 0.4" CUDA = "≥ 1.3" PyCall = "1.92" PythonCall = "≥ 0.7" diff --git a/src/DLPack.jl b/src/DLPack.jl index 9b4eee3..fe741bc 100644 --- a/src/DLPack.jl +++ b/src/DLPack.jl @@ -17,6 +17,7 @@ module DLPack ## Dependencies ## +using BFloat16s: BFloat16 using Requires @@ -318,6 +319,7 @@ jltypes_to_dtypes() = Dict( UInt16 => DLDataType(kDLUInt, 16, 1), UInt32 => DLDataType(kDLUInt, 32, 1), UInt64 => DLDataType(kDLUInt, 64, 1), + BFloat16 => DLDataType(kDLBfloat, 16, 1), Float16 => DLDataType(kDLFloat, 16, 1), Float32 => DLDataType(kDLFloat, 32, 1), Float64 => DLDataType(kDLFloat, 64, 1), @@ -340,6 +342,7 @@ dtypes_to_jltypes() = Dict( DLDataType(kDLUInt, 16, 1) => UInt16, DLDataType(kDLUInt, 32, 1) => UInt32, DLDataType(kDLUInt, 64, 1) => UInt64, + DLDataType(kDLBfloat, 16, 1) => BFloat16, DLDataType(kDLFloat, 16, 1) => Float16, DLDataType(kDLFloat, 32, 1) => Float32, DLDataType(kDLFloat, 64, 1) => Float64, diff --git a/test/setup_python_env.jl b/test/setup_python_env.jl index 73a7a9b..53cbc17 100644 --- a/test/setup_python_env.jl +++ b/test/setup_python_env.jl @@ -8,10 +8,10 @@ python_deps = if VERSION == v"1.6.7" ] else [ - "jax", + "jax<0.4", ] end -push!(python_deps, "pytorch", "setuptools<70") +push!(python_deps, "numpy<2.1", "pytorch", "setuptools<70") CondaPkg.add(CondaPkg.PkgREPL.parse_pkg.(python_deps))