LogitSamplers

A Julia package for GPU-friendly sampling from logit distributions with various transformation methods commonly used in language models.

Usage

The package provides a set of logit transforms that modify a distribution in the log domain. Transforms are ordinary callables, so they can be chained with Julia's ∘ composition operator.

using LogitSamplers

# Create a temperature transform
temperature = Temperature(1.5)

# Create a top-p transform
top_p = Top_p(0.5)

# Compose a function that first applies temperature, then top-p
transform = top_p ∘ temperature

# Create a token index sampler function from the transform
sampler = logitsample ∘ transform

# or equivalently:
sampler = logits -> logitsample(top_p(temperature(logits)))

logits = randn(100)

# Get token probabilities with the transformed logits
probs = softmax(transform(logits))

# Sample a logit index from the sampler
index = sampler(logits)
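
For readers who want to see what these transforms do to the logits, here is a minimal plain-Julia sketch of temperature scaling, top-p filtering, and sampling. It is not the package's implementation: every `sketch_*` name is hypothetical, masking with `-Inf` is an assumption about how filtered tokens are represented, and the code makes no attempt to be GPU-friendly.

```julia
# Illustrative only: plain-Julia versions of the ideas, NOT LogitSamplers' code.
# All sketch_* names are hypothetical.

# Numerically stable softmax over a vector of logits.
function sketch_softmax(logits::AbstractVector{<:Real})
    shifted = exp.(logits .- maximum(logits))
    return shifted ./ sum(shifted)
end

# Temperature scaling: divide logits by T (T > 1 flattens, T < 1 sharpens).
sketch_temperature(T::Real) = logits -> logits ./ T

# Top-p (nucleus) filtering: keep the smallest set of highest-probability
# tokens whose cumulative probability reaches p; mask the rest with -Inf
# (an assumed convention).
function sketch_top_p(p::Real)
    return function (logits)
        probs = sketch_softmax(logits)
        order = sortperm(probs; rev=true)
        cumulative = cumsum(probs[order])
        # Position of the first token at which the running mass reaches p.
        cutoff = something(findfirst(>=(p), cumulative), length(order))
        keep = order[1:cutoff]
        masked = fill(-Inf, length(logits))
        masked[keep] = logits[keep]
        return masked
    end
end

# Draw an index from the categorical distribution defined by the logits
# using inverse-CDF sampling.
function sketch_sample(logits)
    probs = sketch_softmax(logits)
    u = rand()
    idx = findfirst(>(u), cumsum(probs))
    # Guard against floating-point round-off in the cumulative sum.
    return something(idx, findlast(>(0), probs))
end

# Same composition pattern as above: temperature first, then top-p.
sketch_transform = sketch_top_p(0.5) ∘ sketch_temperature(1.5)
sketch_sampler = sketch_sample ∘ sketch_transform

logits = randn(100)
probs = sketch_softmax(sketch_transform(logits))
index = sketch_sampler(logits)
```

Because the masked entries are `-Inf`, the softmax assigns them exactly zero probability, so the sampled index always comes from the retained set.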