Skip to content

Commit

Permalink
Logit transforms (#4)
Browse files Browse the repository at this point in the history
* Add logitsample

* Add logit transforms

* tweaks

* mutable types

* Add tests; Add Temperature, Top_pk, Top_p, Top_k; Add deprecated functions for compat

* Remove x86 and 1.9 testing

* Update README

* Add test for show method
  • Loading branch information
AntonOresten authored Nov 28, 2024
1 parent d8a397c commit 00a13bf
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 88 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ jobs:
matrix:
version:
- '1.10'
- '1.9'
- 'pre'
os:
- ubuntu-latest
arch:
- x64
- x86
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
/docs/Manifest.toml
/docs/build/
Manifest.toml
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
name = "LogitSamplers"
uuid = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041"
authors = ["murrellb <murrellb@gmail.com> and contributors"]
version = "1.0.0-DEV"
version = "1.1.0-DEV"

[deps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
NNlib = "0.9"
Random = "1.11.0"
StatsBase = "0.34"
Random = "1"
Statistics = "1"
julia = "1.9"

[extras]
Expand Down
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,36 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/LogitSamplers.jl/dev/)
[![Build Status](https://github.com/MurrellGroup/LogitSamplers.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/LogitSamplers.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/MurrellGroup/LogitSamplers.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/LogitSamplers.jl)

A Julia package for GPU-friendly sampling from logit distributions with various transformation methods commonly used in language models.

## Usage

The package provides a set of logit transforms to modify the distributions in the log domain.

```julia
using LogitSamplers

# Create a temperature transform
temperature = Temperature(1.5)

# Create a top-p transform
top_p = Top_p(0.5)

# Compose a function that first applies temperature, then top-p
transform = top_p temperature

# Create a token index sampler function from the transform
sampler = logitsample transform

# or equivalently:
sampler = logits -> logitsample(top_p(temperature(logits)))

logits = randn(100)

# Get token probabilities from the sampler
probs = softmax(transform(logits))

# Sample a logit index from the sampler
index = sampler(logits)
```
17 changes: 15 additions & 2 deletions src/LogitSamplers.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
module LogitSamplers

using NNlib, StatsBase
using NNlib: softmax
using Random
using Statistics: std

include("samplers.jl")
include("mask.jl")

include("sample.jl")
export logitsample

include("transforms.jl")
export LogitTransform
export Temperature
export Top_pk, Top_p, Top_k
export Min_p
export Top_nσ

include("deprecated.jl")
export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler

end
11 changes: 11 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@deprecate argmax_sampler(logits; device=identity) Top_k(1)(device(logits))
@deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...)

@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) Top_pk(p, k)(device(logits))
@deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...)

@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) Min_p(pbase)(device(logits))
@deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...)

@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) Top_nσ(temperature, n)(device(logits))
@deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...)
8 changes: 8 additions & 0 deletions src/mask.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
apply_mask(x::AbstractVector{T}, mask::AbstractVector{Bool}) where T<:AbstractFloat =
T(-Inf) * .!mask + x

function create_mask(x::AbstractVector, indices::AbstractVector{Int})
mask = similar(x, Bool) .= false
mask[indices] .= true
return mask
end
15 changes: 15 additions & 0 deletions src/sample.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
logitsample([rng], logits, [buffer=similar(logits)]) -> Int
Sample an index from a logit distribution using the Gumbel argmax trick.
Alternatively pass a buffer to avoid allocating a new array when creating
the random numbers.
"""
function logitsample(rng::AbstractRNG, x::AbstractVector{T}, u::AbstractVector{T}=similar(x)) where T<:AbstractFloat
length(x) == length(u) || throw(DimensionMismatch("Expected buffer of same length as logits"))
rand!(rng, u)
argmax(-log.(-log.(u)) + x)
end

@inline logitsample(args...) = logitsample(Random.default_rng(), args...)
79 changes: 0 additions & 79 deletions src/samplers.jl

This file was deleted.

74 changes: 74 additions & 0 deletions src/transforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
abstract type LogitTransform <: Function end

Base.show(io::IO, ::MIME"text/plain", t::LogitTransform) = show(io, t)


"""
Temperature(T)
A logit transform that scales (divides) the logits by a temperature parameter.
"""
mutable struct Temperature{T<:Real} <: LogitTransform
T::T
end

(t::Temperature)(logits::AbstractVector{T}) where T = logits / T(t.T)


"""
Top_pk(p, k)
A logit transform that masks logits to only include tokens in the top `k` or the top `p` cumulative probability.
"""
mutable struct Top_pk{P<:Real,K<:Union{Integer,Nothing}} <: LogitTransform
p::P
k::K
end

function (t::Top_pk)(logits::AbstractVector{T}) where T<:AbstractFloat
0 < t.p <= 1 || throw(DomainError(t.p, "p must be in the interval (0, 1]"))
probs = softmax(logits)
sorted_probs = sort(probs, rev=true)
cutoff_p = maximum(sorted_probs[cumsum(sorted_probs) .>= t.p]; init=zero(T))
cutoff_k = t.k isa Integer ? maximum(sorted_probs[t.k:t.k]) : zero(T)
return apply_mask(logits, probs .>= max(cutoff_p, cutoff_k))
end

Top_p(p) = Top_pk(p, nothing)
Top_k(k) = Top_pk(1, k)


"""
Min_p(pbase)
A logit transform that samples from the most probable tokens using the min-p strategy.
See: https://arxiv.org/pdf/2407.01082
"""
mutable struct Min_p{T<:Real} <: LogitTransform
pbase::T
end

function (t::Min_p)(logits::AbstractVector)
p = softmax(logits)
return apply_mask(logits, p .>= t.pbase * maximum(p))
end


"""
Top_nσ(n)
A logit transform that samples within `n` standard deviations of the maximum logit.
Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature.
See: https://arxiv.org/pdf/2411.07641
"""
mutable struct Top_nσ{T<:Real} <: LogitTransform
n::T
end

function (t::Top_nσ)(logits::AbstractVector)
M, σ = maximum(logits), std(logits)
return apply_mask(logits, logits .>= M - t.n * σ)
end
Loading

0 comments on commit 00a13bf

Please sign in to comment.