From 7d4f50b11a02ace348f2c02bea97fa329cd4bdb4 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Thu, 28 Nov 2024 19:04:34 +0100 Subject: [PATCH] Un-deprecate sampler interface --- src/LogitSamplers.jl | 2 +- src/deprecated.jl | 11 ----------- src/samplers.jl | 11 +++++++++++ 3 files changed, 12 insertions(+), 12 deletions(-) delete mode 100644 src/deprecated.jl create mode 100644 src/samplers.jl diff --git a/src/LogitSamplers.jl b/src/LogitSamplers.jl index e549d77..00a7f1f 100644 --- a/src/LogitSamplers.jl +++ b/src/LogitSamplers.jl @@ -16,7 +16,7 @@ export Top_pk, Top_p, Top_k export Min_p export Top_nσ -include("deprecated.jl") +include("samplers.jl") export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler end diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index fe63845..0000000 --- a/src/deprecated.jl +++ /dev/null @@ -1,11 +0,0 @@ -@deprecate argmax_sampler(logits; device=identity) logits |> device |> Top_k(1) |> logitsample -@deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...) - -@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) logits |> device |> Top_pk(p, k) |> logitsample -@deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...) - -@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) logits |> device |> Min_p(pbase) |> logitsample -@deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...) - -@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) logits |> device |> Temperature(temperature) |> Top_nσ(n) |> logitsample -@deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...) diff --git a/src/samplers.jl b/src/samplers.jl new file mode 100644 index 0000000..23a0e85 --- /dev/null +++ b/src/samplers.jl @@ -0,0 +1,11 @@ +argmax_sampler(logits; device=identity) = logits |> device |> Top_k(1) |> logitsample +argmax_sampler(; kwargs...) = logits -> argmax_sampler(logits; kwargs...) + +top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) = logits |> device |> Top_pk(p, k) |> logitsample +top_pk_sampler(; kwargs...) = logits -> top_pk_sampler(logits; kwargs...) + +min_p_sampler(logits; pbase = 0.5f0, device = identity) = logits |> device |> Min_p(pbase) |> logitsample +min_p_sampler(; kwargs...) = logits -> min_p_sampler(logits; kwargs...) + +top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) = logits |> device |> Temperature(temperature) |> Top_nσ(n) |> logitsample +top_nσ_sampler(; kwargs...) = logits -> top_nσ_sampler(logits; kwargs...)