diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index c8bb533..76f1cc3 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -24,13 +24,11 @@ jobs:
       matrix:
         version:
           - '1.10'
-          - '1.9'
           - 'pre'
         os:
           - ubuntu-latest
         arch:
           - x64
-          - x86
     steps:
       - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v2
diff --git a/.gitignore b/.gitignore
index 0887050..31d573e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@
 *.jl.mem
 /docs/Manifest.toml
 /docs/build/
+Manifest.toml
diff --git a/Project.toml b/Project.toml
index f7857f8..7b9a12b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,17 +1,17 @@
 name = "LogitSamplers"
 uuid = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041"
 authors = ["murrellb and contributors"]
-version = "1.0.0-DEV"
+version = "1.1.0-DEV"
 
 [deps]
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 NNlib = "0.9"
-Random = "1.11.0"
-StatsBase = "0.34"
+Random = "1"
+Statistics = "1"
 julia = "1.9"
 
 [extras]
diff --git a/README.md b/README.md
index 6bf8c8f..5e54f0b 100644
--- a/README.md
+++ b/README.md
@@ -4,3 +4,37 @@
 [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/LogitSamplers.jl/dev/)
 [![Build Status](https://github.com/MurrellGroup/LogitSamplers.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/LogitSamplers.jl/actions/workflows/CI.yml?query=branch%3Amain)
 [![Coverage](https://codecov.io/gh/MurrellGroup/LogitSamplers.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/LogitSamplers.jl)
+
+A Julia package for GPU-friendly sampling from logit distributions, with transforms commonly used in language models.
+
+## Usage
+
+The package provides a set of logit transforms that modify token distributions in the log domain.
+
+```julia
+using LogitSamplers
+using NNlib: softmax # softmax is not re-exported by LogitSamplers
+
+# Create a temperature transform
+temperature = Temperature(1.5)
+
+# Create a top-p transform
+top_p = Top_p(0.5)
+
+# Compose a function that first applies temperature, then top-p
+transform = top_p ∘ temperature
+
+# Create a token index sampler function from the transform
+sampler = logitsample ∘ transform
+
+# or equivalently:
+sampler = logits -> logitsample(top_p(temperature(logits)))
+
+logits = randn(100)
+
+# Get token probabilities from the transformed logits
+probs = softmax(transform(logits))
+
+# Sample a token index from the logits
+index = sampler(logits)
+```
diff --git a/src/LogitSamplers.jl b/src/LogitSamplers.jl
index 2524dc8..e549d77 100644
--- a/src/LogitSamplers.jl
+++ b/src/LogitSamplers.jl
@@ -1,9 +1,22 @@
 module LogitSamplers
 
-using NNlib, StatsBase
+using NNlib: softmax
+using Random
+using Statistics: std
 
-include("samplers.jl")
+include("mask.jl")
 
+include("sample.jl")
+export logitsample
+
+include("transforms.jl")
+export LogitTransform
+export Temperature
+export Top_pk, Top_p, Top_k
+export Min_p
+export Top_nσ
+
+include("deprecated.jl")
 export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler
 
 end
diff --git a/src/deprecated.jl b/src/deprecated.jl
new file mode 100644
index 0000000..2a04925
--- /dev/null
+++ b/src/deprecated.jl
@@ -0,0 +1,11 @@
+@deprecate argmax_sampler(logits; device=identity) Top_k(1)(device(logits))
+@deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...)
+
+@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) Top_pk(p, k)(device(logits))
+@deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...)
+
+@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) Min_p(pbase)(device(logits))
+@deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...)
+
+@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) Top_nσ(n)(Temperature(temperature)(device(logits)))
+@deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...)
diff --git a/src/mask.jl b/src/mask.jl
new file mode 100644
index 0000000..a4ba8ed
--- /dev/null
+++ b/src/mask.jl
@@ -0,0 +1,8 @@
+apply_mask(x::AbstractVector{T}, mask::AbstractVector{Bool}) where T<:AbstractFloat =
+    T(-Inf) * .!mask + x
+
+function create_mask(x::AbstractVector, indices::AbstractVector{Int})
+    mask = similar(x, Bool) .= false
+    mask[indices] .= true
+    return mask
+end
\ No newline at end of file
diff --git a/src/sample.jl b/src/sample.jl
new file mode 100644
index 0000000..857bc5c
--- /dev/null
+++ b/src/sample.jl
@@ -0,0 +1,15 @@
+"""
+    logitsample([rng], logits, [buffer=similar(logits)]) -> Int
+
+Sample an index from a logit distribution using the Gumbel argmax trick.
+
+Optionally pass a pre-allocated buffer to avoid allocating a new array
+for the random numbers.
+"""
+function logitsample(rng::AbstractRNG, x::AbstractVector{T}, u::AbstractVector{T}=similar(x)) where T<:AbstractFloat
+    length(x) == length(u) || throw(DimensionMismatch("Expected buffer of same length as logits"))
+    rand!(rng, u)
+    argmax(-log.(-log.(u)) + x)
+end
+
+@inline logitsample(args...) = logitsample(Random.default_rng(), args...)
diff --git a/src/samplers.jl b/src/samplers.jl
deleted file mode 100644
index 3d69a8b..0000000
--- a/src/samplers.jl
+++ /dev/null
@@ -1,79 +0,0 @@
-#Gumbel argmax trick for GPU sampling
-function samplelogits(logits::AbstractVector)
-    u = similar(logits)
-    rand!(u)
-    return argmax(-log.(-log.(u)) .+ logits)
-end
-
-#To do: refactor into a combination of modified_softmax and sample. This way we can viz the result of the modified logits without having to sample.
-#This won't be visible to the user. Any method that doesn't fit this interface can be implemented directly.
-
-function argmax_sampler(logits::AbstractVector; device = identity)
-    return argmax(device(logits))
-end
-
-"""
-    argmax_sampler(; device = identity)
-
-Returns a function that samples most likely token.
-"""
-argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device)
-
-function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity)
-    probs = device(softmax(logits))
-    perm = partialsortperm(probs, 1:k, rev=true)
-    sorted_probs = probs[perm]
-    cumsum_probs = cumsum(sorted_probs)
-    if cumsum_probs[1] > p
-        return perm[1]
-    else
-        cutoff = findlast(cumsum_probs .< p)
-        return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff]))
-    end
-end
-
-"""
-    top_pk_sampler(; p = 0.5f0, k = 5, device = identity)
-
-Returns a function that samples from the most probable tokens. A combination of the top-k and top-p sampling methods, where you can sample from the top tokens with cumulative probability `p`, with a max number of tokens `k`.
-"""
-top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device)
-
-#https://arxiv.org/pdf/2407.01082
-function min_p_sampler(logits::AbstractVector{T}; pbase::T = 0.5f0, device = identity) where T
-    probs = device(softmax(logits))
-    pmax = maximum(probs)
-    pscaled = pbase * pmax
-    mask = probs .>= pscaled
-    if !any(mask)
-        mask[argmax(probs)] = true
-    end
-    probs[.!mask] .= zero(T)
-    return sample(1:length(probs), Weights(probs))
-end
-
-"""
-    min_p_sampler(; pbase = 0.5f0, device = identity)
-
-Returns a function that samples from the most probable tokens using the min-p strategy. See: https://arxiv.org/pdf/2407.01082
-"""
-min_p_sampler(; pbase = 0.5f0, device = identity) = logits -> min_p_sampler(logits; pbase, device)
-
-# https://arxiv.org/pdf/2411.07641
-function top_nσ_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T
-    scaled_logits = logits ./ temperature
-    M = maximum(scaled_logits)
-    σ = std(scaled_logits)
-    threshold = M - n * σ
-    mask = scaled_logits .>= threshold
-    scaled_logits[.!mask] .= -Inf
-    probs = device(softmax(scaled_logits))
-    return sample(1:length(probs), Weights(probs))
-end
-
-"""
-    top_nσ_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity)
-
-Returns a function that samples from the most probable tokens using the top-nσ strategy. See: https://arxiv.org/pdf/2411.07641
-"""
-top_nσ_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> top_nσ_sampler(logits; temperature, n, device)
diff --git a/src/transforms.jl b/src/transforms.jl
new file mode 100644
index 0000000..68931fc
--- /dev/null
+++ b/src/transforms.jl
@@ -0,0 +1,74 @@
+abstract type LogitTransform <: Function end
+
+Base.show(io::IO, ::MIME"text/plain", t::LogitTransform) = show(io, t)
+
+
+"""
+    Temperature(T)
+
+A logit transform that scales (divides) the logits by a temperature parameter.
+"""
+mutable struct Temperature{T<:Real} <: LogitTransform
+    T::T
+end
+
+(t::Temperature)(logits::AbstractVector{T}) where T = logits / T(t.T)
+
+
+"""
+    Top_pk(p, k)
+
+A logit transform that masks logits to tokens lying within both the top `k` tokens and the top `p` cumulative probability mass.
+"""
+mutable struct Top_pk{P<:Real,K<:Union{Integer,Nothing}} <: LogitTransform
+    p::P
+    k::K
+end
+
+function (t::Top_pk)(logits::AbstractVector{T}) where T<:AbstractFloat
+    0 < t.p <= 1 || throw(DomainError(t.p, "p must be in the interval (0, 1]"))
+    probs = softmax(logits)
+    sorted_probs = sort(probs, rev=true)
+    cutoff_p = maximum(sorted_probs[cumsum(sorted_probs) .>= t.p]; init=zero(T))
+    cutoff_k = t.k isa Integer ? maximum(sorted_probs[t.k:t.k]) : zero(T)
+    return apply_mask(logits, probs .>= max(cutoff_p, cutoff_k))
+end
+
+Top_p(p) = Top_pk(p, nothing)
+Top_k(k) = Top_pk(1, k)
+
+
+"""
+    Min_p(pbase)
+
+A logit transform that masks logits using the min-p strategy, keeping tokens whose probability is at least `pbase` times that of the most probable token.
+
+See: https://arxiv.org/pdf/2407.01082
+"""
+mutable struct Min_p{T<:Real} <: LogitTransform
+    pbase::T
+end
+
+function (t::Min_p)(logits::AbstractVector)
+    p = softmax(logits)
+    return apply_mask(logits, p .>= t.pbase * maximum(p))
+end
+
+
+"""
+    Top_nσ(n)
+
+A logit transform that masks logits to tokens within `n` standard deviations of the maximum logit.
+
+Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature.
+
+See: https://arxiv.org/pdf/2411.07641
+"""
+mutable struct Top_nσ{T<:Real} <: LogitTransform
+    n::T
+end
+
+function (t::Top_nσ)(logits::AbstractVector)
+    M, σ = maximum(logits), std(logits)
+    return apply_mask(logits, logits .>= M - t.n * σ)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 0c21ea6..d0afcb1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,143 @@
 using LogitSamplers
 using Test
+using Random
 
 @testset "LogitSamplers.jl" begin
-    # Write your tests here.
+
+    @testset "mask.jl" begin
+        n = 6
+        logits = randn(n)
+        indices = [2, 4, 5]
+        mask = LogitSamplers.create_mask(logits, indices)
+        @test mask == Bool[0, 1, 0, 1, 1, 0]
+        @test all(mask[indices])
+        @test !any(mask[setdiff(1:n, indices)])
+        @test LogitSamplers.apply_mask(logits, mask)[mask] == logits[mask]
+        @test all(isinf, LogitSamplers.apply_mask(logits, mask)[.!mask])
+    end
+
+    @testset "sample.jl" begin
+
+        @testset "Basic operations" begin
+            logits = randn(10)
+            buffer = similar(logits)
+
+            @test logitsample(logits) isa Integer
+            @test logitsample(logits, buffer) isa Integer
+            @test logitsample(Random.MersenneTwister(123), logits) == logitsample(Random.MersenneTwister(123), logits, buffer)
+        end
+
+        @testset "Statistical properties" begin
+            n_samples = 10000
+            logits = log.([0.25, 0.75])
+            counts = zeros(Int, 2)
+            for _ in 1:n_samples
+                idx = logitsample(logits)
+                counts[idx] += 1
+            end
+
+            probs = counts ./ n_samples
+            @test isapprox(probs[2] / probs[1], 3.0, rtol=0.1)
+        end
+
+        @testset "Edge cases" begin
+            large_logits = [1000.0, -1000.0]
+            @test logitsample(large_logits) == 1
+
+            equal_logits = zeros(3)
+            samples = [logitsample(equal_logits) for _ in 1:1000]
+            @test length(unique(samples)) == 3
+        end
+
+    end
+
+    @testset "transforms.jl" begin
+
+        count_remaining(logits) = count(!isinf, logits)
+
+        @testset "Temperature" begin
+            @test Temperature(1.0) isa Temperature
+
+            @test repr(Temperature(1.0)) == "Temperature{Float64}(1.0)"
+
+            logits = log.([0.1, 0.2, 0.3, 0.4])
+
+            @test count_remaining(Temperature(0.5)(logits)) == 4
+            @test count_remaining(Temperature(1.0)(logits)) == 4
+            @test count_remaining(Temperature(2.0)(logits)) == 4
+        end
+
+        @testset "Top_pk" begin
+            @test Top_pk(0.5, 3) isa Top_pk
+            @test Top_p(0.5) isa Top_pk
+            @test Top_k(3) isa Top_pk
+
+
+            logits = log.([0.1, 0.2, 0.3, 0.4])
+
+            @testset "Top_pk" begin
+                @test count_remaining(Top_pk(0.30, 3)(logits)) == 1
+                @test count_remaining(Top_pk(0.50, 1)(logits)) == 1
+                @test count_remaining(Top_pk(0.50, 3)(logits)) == 2
+                @test count_remaining(Top_pk(0.80, 2)(logits)) == 2
+                @test count_remaining(Top_pk(0.80, 3)(logits)) == 3
+                @test count_remaining(Top_pk(0.95, 3)(logits)) == 3
+                @test count_remaining(Top_pk(0.95, 4)(logits)) == 4
+                @test count_remaining(Top_pk(1.00, 4)(logits)) == 4
+
+                @test_throws DomainError Top_pk(-1.0, 3)(logits)
+                @test_throws DomainError Top_pk(2.0, 3)(logits)
+
+                @test_throws BoundsError Top_pk(0.5, 5)(logits)
+                @test_throws BoundsError Top_pk(0.5, 0)(logits)
+            end
+
+            @testset "Top_p" begin
+                @test count_remaining(Top_p(0.30)(logits)) == 1
+                @test count_remaining(Top_p(0.50)(logits)) == 2
+                @test count_remaining(Top_p(0.80)(logits)) == 3
+                @test count_remaining(Top_p(0.95)(logits)) == 4
+                @test count_remaining(Top_p(1.00)(logits)) == 4
+
+                @test_throws DomainError Top_p(-1.0)(logits)
+                @test_throws DomainError Top_p(2.0)(logits)
+            end
+
+            @testset "Top_k" begin
+                @test count_remaining(Top_k(1)(logits)) == 1
+                @test count_remaining(Top_k(2)(logits)) == 2
+                @test count_remaining(Top_k(3)(logits)) == 3
+                @test count_remaining(Top_k(4)(logits)) == 4
+
+                @test_throws BoundsError Top_k(5)(logits)
+                @test_throws BoundsError Top_k(0)(logits)
+            end
+        end
+
+        @testset "Min_p" begin
+            @test Min_p(0.5) isa Min_p
+
+            logits = log.([0.1, 0.2, 0.3, 0.4])
+
+            @test count_remaining(Min_p(1.0)(logits)) == 1
+            @test count_remaining(Min_p(0.8)(logits)) == 1
+            @test count_remaining(Min_p(0.6)(logits)) == 2
+            @test count_remaining(Min_p(0.4)(logits)) == 3
+            @test count_remaining(Min_p(0.1)(logits)) == 4
+            @test count_remaining(Min_p(0.0)(logits)) == 4
+        end
+
+        @testset "Top_nσ" begin
+            @test Top_nσ(1.0) isa Top_nσ
+
+            logits = log.([0.1, 0.2, 0.3, 0.4])
+
+            @test count_remaining(Top_nσ(0.0)(logits)) == 1
+            @test count_remaining(Top_nσ(1.0)(logits)) == 2
+            @test count_remaining(Top_nσ(2.0)(logits)) == 3
+            @test count_remaining(Top_nσ(3.0)(logits)) == 4
+        end
+
+    end
+
+end
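
For reference, a minimal end-to-end sketch of the API this diff introduces. The vocabulary size, seed, and parameter values below are illustrative assumptions, not taken from the package; `softmax` comes from NNlib since it is not re-exported by LogitSamplers.

```julia
using LogitSamplers
using NNlib: softmax
using Random: Xoshiro

logits = randn(Float32, 32000)  # stand-in for a language-model vocabulary

# Compose right-to-left: Temperature is applied first, then Top_p masking.
transform = Top_p(0.9f0) ∘ Temperature(0.7f0)

# Tokens outside the nucleus are masked to -Inf, so softmax assigns them zero.
masked = transform(logits)
probs = softmax(masked)

# Reproducible Gumbel-argmax sampling with an explicit RNG and a reusable buffer.
rng = Xoshiro(42)
buffer = similar(logits)
index = logitsample(rng, masked, buffer)
@assert probs[index] > 0  # the sampled token always lies inside the mask
```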