diff --git a/Project.toml b/Project.toml index 0c081ab4..10477dbd 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -69,7 +68,7 @@ Pkg = "1.10" Preferences = "1.4" Random = "1.10" ReTestItems = "1.23.1" -Reexport = "1" +Reexport = "1.2.2" ReverseDiff = "1.15" SLEEFPirates = "0.6.43" Setfield = "1.1.1" @@ -96,6 +95,7 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Preferences = "21216c6a-2e73-6563-6e65-726566657250" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -104,4 +104,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "Reexport", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"] diff --git a/src/LuxLib.jl b/src/LuxLib.jl index 766f7642..ddc7d15f 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -16,15 +16,11 @@ using MLDataDevices: get_device_type, AMDGPUDevice, CUDADevice, CPUDevice, using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, σ, ∇conv_data, ∇conv_filter using Octavian: Octavian using Random: Random, AbstractRNG, rand! -using Reexport: @reexport -using Setfield: @set! using StaticArraysCore: StaticArraysCore, StaticArray, StaticVector using Statistics: Statistics, mean, var using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter, unrolled_mapreduce -@reexport using NNlib - const CRC = ChainRulesCore const KA = KernelAbstractions diff --git a/test/common_ops/activation_tests.jl b/test/common_ops/activation_tests.jl index 803abee5..15b5e534 100644 --- a/test/common_ops/activation_tests.jl +++ b/test/common_ops/activation_tests.jl @@ -1,4 +1,6 @@ @testitem "Activation Functions" tags=[:other_ops] setup=[SharedTestSetup] begin + using NNlib + rng = StableRNG(1234) apply_act(f::F, x) where {F} = sum(abs2, f.(x)) diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl index b2a0f065..38e9802f 100644 --- a/test/common_ops/dense_tests.jl +++ b/test/common_ops/dense_tests.jl @@ -101,7 +101,7 @@ end end @testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays + using StaticArrays, NNlib x = @SArray rand(2, 4) weight = @SArray rand(3, 2) @@ -111,7 +111,7 @@ end end @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin - using JLArrays + using JLArrays, NNlib x = JLArray(rand(Float32, 2, 4)) weight = JLArray(rand(Float32, 3, 2)) diff --git a/test/others/forwarddiff_tests.jl b/test/others/forwarddiff_tests.jl index 23c279e8..c75a50c7 100644 --- a/test/others/forwarddiff_tests.jl +++ b/test/others/forwarddiff_tests.jl @@ -1,5 +1,5 @@ @testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin - using ForwardDiff, Zygote, ComponentArrays + using ForwardDiff, Zygote, ComponentArrays, NNlib using LuxTestUtils: check_approx # Computes (∂f/∂x)u diff --git a/test/others/qa_tests.jl b/test/others/qa_tests.jl index b00fa347..11c6d5a4 100644 --- a/test/others/qa_tests.jl +++ b/test/others/qa_tests.jl @@ -1,5 +1,5 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin - using Aqua, ChainRulesCore, EnzymeCore + using Aqua, ChainRulesCore, EnzymeCore, NNlib using EnzymeCore: EnzymeRules Aqua.test_all(LuxLib; ambiguities=false, piracies=false)