From c3370686eaf7601813f20c5999ae3ac85bb8d85e Mon Sep 17 00:00:00 2001
From: Gregory Kielian
Date: Tue, 9 Apr 2024 20:35:27 -0700
Subject: [PATCH] Add Era of 1.58-bit LLMs BitLinear implementation

Add an MIT-licensed ternary implementation of BitLinear:
https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py

Ternary BitLinear arXiv paper: https://arxiv.org/abs/2402.17764
---
 explorations/linear_sweep.json  |  2 +-
 model.py                        |  2 +-
 train.py                        |  1 +
 variations/linear_variations.py | 45 +++++++++++++++++++++++++++++++++
 4 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/explorations/linear_sweep.json b/explorations/linear_sweep.json
index 4e106c3c3d..db68c4fbaa 100644
--- a/explorations/linear_sweep.json
+++ b/explorations/linear_sweep.json
@@ -11,7 +11,7 @@
   "device": ["cuda"],
   "dtype": ["float16"],
   "dataset": ["shakespeare_char"],
-  "linear_variant": ["bitlinear", "bitlinear_optimized", "linear"],
+  "linear_variant": ["bitlinear_1p58", "bitlinear", "bitlinear_optimized", "linear"],
   "compile": [true],
   "softmax_variant_attn": ["softmax", "polymax"],
   "tensorboard_run_name": ["linear_variation_sweep"]
diff --git a/model.py b/model.py
index a1883c9536..c3e14f4524 100644
--- a/model.py
+++ b/model.py
@@ -21,7 +21,7 @@
 from variations.normalization_variations import LayerNorm, RMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions
 from variations.activation_variations import SquaredReLU, activation_dictionary
-from variations.linear_variations import BitLinear, BitLinearOptimized, linear_dictionary
+from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
diff --git a/train.py b/train.py
index 4a6a2a24e0..d01e922593 100644
--- a/train.py
+++ b/train.py
@@ -102,6 +102,7 @@ def parse_args():
         choices=[
             "linear",
             "bitlinear",
+            "bitlinear_1p58",
             "bitlinear_optimized",
         ],
     )
diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 328f21f391..141ccfec3e 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -2,6 +2,50 @@
 import torch.nn as nn
 import math
 
+class BitLinear1p58(nn.Linear):
+    """ BitLinear from Era of 1.58 LLMs Paper
+    Source: https://huggingface.co/1bitLLM/bitnet_b1_58-large/blob/main/utils_quant.py
+    Source License: MIT
+    Paper Link: https://arxiv.org/abs/2402.17764
+    """
+
+    def __init__(self, in_features, out_features, bias=True, num_groups=1):
+        super().__init__(in_features, out_features, bias)
+
+        """
+        RMSNorm is placed outside BitLinear
+        """
+        weight_bits=1
+        input_bits=8
+        self.weight_bits = weight_bits
+        self.input_bits = input_bits
+
+    def forward(self, x):
+
+        quant_input = x + (self.activation_quant(x, self.input_bits) - x).detach()
+        quant_weight = self.weight + (self.weight_quant(self.weight, self.weight_bits) - self.weight).detach()
+
+        out = nn.functional.linear(quant_input, quant_weight)
+        if not self.bias is None:
+            out += self.bias.view(1, -1).expand_as(out)
+
+        return out
+
+    def weight_quant(self, weight, num_bits=1):
+        dtype = weight.dtype
+        weight = weight.float()
+        s = 1 / weight.abs().mean().clamp(min=1e-5)
+        result = (weight * s).round().clamp(-1, 1) / s
+        return result.type(dtype)
+
+    def activation_quant(self, x, num_bits=8):
+        dtype = x.dtype
+        x = x.float()
+        Qn = -2 ** (num_bits - 1)
+        Qp = 2 ** (num_bits - 1) - 1
+        s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
+        result = (x * s).round().clamp(Qn, Qp) / s
+        return result.type(dtype)
 
 class BitLinear(nn.Linear):
     """PyTorch BitLinear Layer
@@ -175,4 +219,5 @@ def forward(self, input):
     "linear": nn.Linear,
     "bitlinear": BitLinear,
     "bitlinear_optimized": BitLinearOptimized,
+    "bitlinear_1p58": BitLinear1p58,
 }
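
Reviewer note (not part of the patch): the snippet below is a minimal sketch for sanity-checking the new layer from a Python shell at the repository root. It assumes only what the patch adds (BitLinear1p58 in variations/linear_variations.py); the layer sizes and input tensor are arbitrary.

    import torch
    import torch.nn as nn
    from variations.linear_variations import BitLinear1p58

    torch.manual_seed(0)

    # Put the same weights in a full-precision layer and the ternary layer, bias disabled.
    fp_layer  = nn.Linear(16, 8, bias=False)
    bit_layer = BitLinear1p58(16, 8, bias=False)
    bit_layer.weight.data.copy_(fp_layer.weight.data)

    x = torch.randn(4, 16)
    with torch.no_grad():
        y_fp  = fp_layer(x)
        y_bit = bit_layer(x)

    # Outputs should be close but not identical: weights are rounded to
    # {-1, 0, +1} times a per-tensor scale, activations to 8-bit levels.
    print((y_fp - y_bit).abs().max())

    # The scaled, rounded weights take at most three distinct values.
    s = 1 / bit_layer.weight.detach().abs().mean().clamp(min=1e-5)
    print(torch.unique((bit_layer.weight.detach() * s).round().clamp(-1, 1)))

Because the quantized tensors are built as x + (quant(x) - x).detach(), the forward pass uses the quantized values while gradients flow through the full-precision parameters (a straight-through estimator), so the layer trains like a regular nn.Linear.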