Skip to content

Commit

Permalink
Fix old interface returning logits
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Nov 28, 2024
1 parent ad99284 commit edc157e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
@deprecate argmax_sampler(logits; device=identity) Top_k(1)(device(logits))
@deprecate argmax_sampler(logits; device=identity) logits |> device |> Top_k(1) |> logitsample
@deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...)

@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) Top_pk(p, k)(device(logits))
@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) logits |> device |> Top_pk(p, k) |> logitsample
@deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...)

@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) Min_p(pbase)(device(logits))
@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) logits |> device |> Min_p(pbase) |> logitsample
@deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...)

@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) Top_nσ(temperature, n)(device(logits))
@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) logits |> device |> Temperature(temperature) |> Top_nσ(n) |> logitsample
@deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...)
13 changes: 13 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,17 @@ using Random

end

@testset "deprecated.jl" begin
logits = log.([0.1, 0.2, 0.3, 0.4])
@test argmax_sampler(logits) isa Integer
@test_throws BoundsError top_pk_sampler(logits) # k defaults to 5
@test top_pk_sampler(logits; k = 3) isa Integer
@test min_p_sampler(logits) isa Integer
@test top_nσ_sampler(logits) isa Integer
@test argmax_sampler() isa Function
@test top_pk_sampler() isa Function
@test min_p_sampler() isa Function
@test top_nσ_sampler() isa Function
end

end

2 comments on commit edc157e

@AntonOresten
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Add logitsample function for GPU-friendly weighted sampling in the log domain.
  • Add abstract LogitTransform type.
    • Add Temperate type.
    • Add Top_pk type with additional Top_p and Top_k constructors.
    • Add Min_p type.
    • Add Top_nσ type.
  • Deprecate pre-release interface.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/120312

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" edc157eda2b7c8c97e1cfcac155a886b51e86369
git push origin v0.1.0

Please sign in to comment.