diff --git a/variations/softmax_variations.py b/variations/softmax_variations.py index 6e9025ea7a..ac84d2eaaa 100644 --- a/variations/softmax_variations.py +++ b/variations/softmax_variations.py @@ -14,7 +14,7 @@ def forward(self, x): if self.subtract_max: max_x = x.max(dim=self.dim, keepdim=True).values x = x - max_x - e_x = torch.pow(math.e, x) + e_x = torch.pow(2.0, x) return e_x / e_x.sum(dim=self.dim, keepdim=True) # Softmax variation with learnable constant parameters for xmax and denominator