From edc157eda2b7c8c97e1cfcac155a886b51e86369 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Thu, 28 Nov 2024 18:15:53 +0100 Subject: [PATCH] Fix old interface returning logits --- src/deprecated.jl | 8 ++++---- test/runtests.jl | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/deprecated.jl b/src/deprecated.jl index 2a04925..fe63845 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,11 +1,11 @@ -@deprecate argmax_sampler(logits; device=identity) Top_k(1)(device(logits)) +@deprecate argmax_sampler(logits; device=identity) logits |> device |> Top_k(1) |> logitsample @deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...) -@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) Top_pk(p, k)(device(logits)) +@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) logits |> device |> Top_pk(p, k) |> logitsample @deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...) -@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) Min_p(pbase)(device(logits)) +@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) logits |> device |> Min_p(pbase) |> logitsample @deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...) -@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) Top_nσ(temperature, n)(device(logits)) +@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) logits |> device |> Temperature(temperature) |> Top_nσ(n) |> logitsample @deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...) diff --git a/test/runtests.jl b/test/runtests.jl index d0afcb1..d740683 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -140,4 +140,17 @@ using Random end + @testset "deprecated.jl" begin + logits = log.([0.1, 0.2, 0.3, 0.4]) + @test argmax_sampler(logits) isa Integer + @test_throws BoundsError top_pk_sampler(logits) # k defaults to 5 + @test top_pk_sampler(logits; k = 3) isa Integer + @test min_p_sampler(logits) isa Integer + @test top_nσ_sampler(logits) isa Integer + @test argmax_sampler() isa Function + @test top_pk_sampler() isa Function + @test min_p_sampler() isa Function + @test top_nσ_sampler() isa Function + end + end