From f0ec33e8b2b3364f697f09c757d9b210f3af46a3 Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Mon, 14 Oct 2024 08:04:51 +0000
Subject: [PATCH] Add optional input clamping for ConSmaxV2

This prevents overflow by saturating large inputs, without having to
resort to piecewise functions.
---
 gpt_conf.py                      | 2 ++
 train.py                         | 2 ++
 variations/softmax_variations.py | 8 ++++++++
 3 files changed, 12 insertions(+)

diff --git a/gpt_conf.py b/gpt_conf.py
index 2a20206014..1d91ff3063 100644
--- a/gpt_conf.py
+++ b/gpt_conf.py
@@ -91,6 +91,8 @@ class GPTConfig:
 
     ## ConSmaxV2 Special Options
     consmax_per_head: bool = True # different beta gamma per head
+    consmax_v2_clamping: bool = True
+    consmax_v2_clamp_value: float = 80.0
 
     ## SaturatingConSmax Special options (otherwise same as ConSmax)
     consmax_saturation: float = 11.0 # for SaturatingConSmax saturation point
diff --git a/train.py b/train.py
index f24c0bf73e..103495fd3c 100644
--- a/train.py
+++ b/train.py
@@ -320,6 +320,8 @@ def parse_args():
 
     ### Special Options for ConSmaxV2
     model_group.add_argument("--consmax_per_head", default=True, action=argparse.BooleanOptionalAction)
+    model_group.add_argument("--consmax_v2_clamping", default=True, action=argparse.BooleanOptionalAction, help="clamp ConSmaxV2 inputs to consmax_v2_clamp_value to prevent overflow")
+    model_group.add_argument("--consmax_v2_clamp_value", type=float, default=80.0, help="maximum value to which ConSmaxV2 inputs are clamped")
 
     ### Special Options for SaturatingConSmax
     model_group.add_argument("--consmax_saturation", type=float, default=11.0, help="point where we transition from consmax to linear saturatingconsmax, defaults to 11 to approximate e^x sat for fp16")
diff --git a/variations/softmax_variations.py b/variations/softmax_variations.py
index 51098214e3..84a400cc05 100644
--- a/variations/softmax_variations.py
+++ b/variations/softmax_variations.py
@@ -82,6 +82,10 @@ def __init__(self, config, dim=-1):
         self.beta = self.beta_init * self.beta_factor
         self.gamma = self.beta_init * self.gamma_factor
 
+        # Set optional clamping (on by default)
+        self.clamp_inputs = config.consmax_v2_clamping
+        self.clamp_value = config.consmax_v2_clamp_value
+
         # Set the base of the exponent
         if config.consmax_use_euler_base:
             self.consmax_base = math.e
@@ -93,7 +97,11 @@ def forward(self, x):
             self.gamma = self.gamma_factor * self.gamma_init
 
         x_adj = x - self.beta
+        if self.clamp_inputs:
+            x_adj[x_adj > self.clamp_value] = self.clamp_value
+
         e_x = torch.pow(self.consmax_base, x_adj)
+
         result = e_x / self.gamma
 
         if self.training and self.softmax_io_logging and self.iter_num % self.softmax_io_log_interval == 0:
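
Note (illustration only, not applied by the patch): a minimal standalone sketch of the
clamping step above, with placeholder scalars for beta and gamma and made-up logit
values; the real module learns beta and gamma, optionally per attention head. It shows
why capping the shifted logits at 80.0 matters: float32 overflows e^x once x exceeds
roughly 88.7 (the natural log of the float32 maximum, about 3.4e38), so larger values
would otherwise become inf before the division by gamma.

    import math
    import torch

    # Sketch of the clamp added in ConSmaxV2.forward (beta/gamma are placeholders here).
    beta, gamma = 0.0, 1.0
    clamp_value = 80.0        # e**80 ~ 5.5e34, still below the float32 max of ~3.4e38

    x = torch.tensor([1.0, 12.0, 95.0, 300.0])   # two ordinary logits, two extreme ones
    x_adj = x - beta
    x_adj[x_adj > clamp_value] = clamp_value     # same masked assignment as in the patch
    e_x = torch.pow(math.e, x_adj)               # finite everywhere thanks to the clamp
    print(e_x / gamma)                           # without the clamp, the last two entries are inf

The existing SaturatingConSmax option already caps at 11.0 for the same reason in fp16
(e**11 is close to the fp16 maximum of 65504, per its help text above); 80.0 appears to
play the analogous role for float32 here.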