Add optional input clamping for ConSmaxV2
This prevents overflow by saturating large inputs, without having to resort to piecewise functions.
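
For intuition, a minimal illustrative sketch (not part of the commit): exponentials overflow float32 (max ~3.4e38) once the exponent passes ~88, so capping inputs at 80 keeps results finite.

import torch

# Illustrative only: exp overflows float32 once its input exceeds ~88,
# so a cap of 80 (the new default clamp value) keeps results finite.
x = torch.tensor([10.0, 80.0, 100.0])
print(torch.exp(x))                  # tensor([2.2026e+04, 5.5406e+34, inf])
print(torch.exp(x.clamp(max=80.0)))  # tensor([2.2026e+04, 5.5406e+34, 5.5406e+34])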
gkielian committed Oct 14, 2024
1 parent 1405df4 commit f0ec33e
Showing 3 changed files with 12 additions and 0 deletions.
2 changes: 2 additions & 0 deletions gpt_conf.py
@@ -91,6 +91,8 @@ class GPTConfig:

## ConSmaxV2 Special Options
consmax_per_head: bool = True # different beta gamma per head
consmax_v2_clamping: bool = True
consmax_v2_clamp_value: float = 80.0

## SaturatingConSmax Special options (otherwise same as ConSmax)
consmax_saturation: float = 11.0 # for SaturatingConSmax saturation point
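A minimal sketch (not from the repo) of setting the new fields when constructing the config, assuming GPTConfig's remaining fields have workable defaults:

from gpt_conf import GPTConfig

config = GPTConfig(
    consmax_v2_clamping=True,     # enable input clamping for ConSmaxV2
    consmax_v2_clamp_value=80.0,  # cap applied to the shifted inputs
)
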
2 changes: 2 additions & 0 deletions train.py
@@ -320,6 +320,8 @@ def parse_args():

### Special Options for ConSmaxV2
model_group.add_argument("--consmax_per_head", default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument("--consmax_v2_clamping", default=False, action=argparse.BooleanOptionalAction)
model_group.add_argument("--consmax_v2_clamp_value", type=float, default=80.0, help="maximum value to clamp inputs")

### Special Options for SaturatingConSmax
model_group.add_argument("--consmax_saturation", type=float, default=11.0, help="point where we transition from consmax to linear saturatingconsmax, defaults to 11 to approximate e^x sat for fp16")
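Since --consmax_v2_clamping is registered with argparse.BooleanOptionalAction, it can be enabled with --consmax_v2_clamping or disabled with --no-consmax_v2_clamping; the threshold is set separately, e.g. python train.py --consmax_v2_clamping --consmax_v2_clamp_value 80.0.
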
8 changes: 8 additions & 0 deletions variations/softmax_variations.py
@@ -82,6 +82,10 @@ def __init__(self, config, dim=-1):
self.beta = self.beta_init * self.beta_factor
self.gamma = self.beta_init * self.gamma_factor

# Set optional clamping (on by default)
self.clamp_inputs = config.consmax_v2_clamping
self.clamp_value = config.consmax_v2_clamp_value

# Set the base of the exponent
if config.consmax_use_euler_base:
self.consmax_base = math.e
@@ -93,7 +97,11 @@ def forward(self, x):
self.gamma = self.gamma_factor * self.gamma_init

x_adj = x - self.beta
if self.clamp_inputs:
x_adj[x_adj > self.clamp_value] = self.clamp_value

e_x = torch.pow(self.consmax_base, x_adj)

result = e_x / self.gamma

if self.training and self.softmax_io_logging and self.iter_num % self.softmax_io_log_interval == 0:
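The masked assignment in forward caps each shifted input at clamp_value from above only. A functionally equivalent out-of-place sketch (a stylistic alternative, not what the commit ships) using torch.clamp:

import torch

def saturate(x_adj: torch.Tensor, clamp_value: float) -> torch.Tensor:
    # Same effect as `x_adj[x_adj > clamp_value] = clamp_value`,
    # but out-of-place, sidestepping in-place autograd versioning pitfalls.
    return torch.clamp(x_adj, max=clamp_value)

Because only max is given, values below the threshold pass through unchanged, matching the one-sided saturation described in the commit message.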
