Skip to content

Commit

Permalink
Gumbel argmax for GPU sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
murrellb committed Nov 27, 2024
1 parent ee77d4b commit d8a397c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ version = "1.0.0-DEV"

[deps]
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
NNlib = "0.9"
Random = "1.11.0"
StatsBase = "0.34"
julia = "1.9"

Expand Down
7 changes: 7 additions & 0 deletions src/samplers.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
#Gumbel argmax trick for GPU sampling
function samplelogits(logits::AbstractVector)
u = similar(logits)
rand!(u)
return argmax(-log.(-log.(u)) .+ logits)
end

#To do: refactor into a combination of modified_softmax and sample. This way we can viz the result of the modified logits without having to sample.
#This won't be visible to the user. Any method that doesn't fit this interface can be implemented directly.

Expand Down

0 comments on commit d8a397c

Please sign in to comment.