Skip to content

Commit

Permalink
enable bnb 4/8bit inference (#2419)
Browse files Browse the repository at this point in the history
  • Loading branch information
vince62s committed Jun 22, 2023
1 parent bd1f52f commit f13dbd8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
4 changes: 4 additions & 0 deletions onmt/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def load_test_model(opt, model_path=None):

model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])

model_opt.quant_layers = opt.quant_layers
model_opt.quant_type = opt.quant_type

ArgumentParser.update_model_opts(model_opt)
ArgumentParser.validate_model_opts(model_opt)
vocabs = dict_to_vocabs(checkpoint["vocab"])
Expand All @@ -118,6 +121,7 @@ def load_test_model(opt, model_path=None):
else:
device = torch.device("cpu")

logger.info("Loading data into the model")
if "model" in checkpoint.keys():
# weights are in the .pt file
model.load_state_dict(
Expand Down
42 changes: 24 additions & 18 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,24 +1133,6 @@ def _add_train_general_opts(parser):
help="rule of thumb: same value as in main model",
)

group.add(
"--quant_layers",
"-quant_layers",
default=[],
nargs="+",
type=str,
help="list of layers to be compressed in 4/8bit.",
)

group.add(
"--quant_type",
"-quant_type",
default="bnb_8bit",
choices=["bnb_8bit", "bnb_FP4", "bnb_NF4"],
type=str,
help="Type of compression.",
)

_add_reproducibility_opts(parser)

# Init options
Expand Down Expand Up @@ -1533,6 +1515,27 @@ def _add_train_dynamic_data(parser):
)


def _add_quant_opts(parser):
    """Register the quantization command-line options on ``parser``.

    Creates a "Quant options" argument group and adds two options to it:

    * ``--quant_layers`` / ``-quant_layers``: one or more layer names
      (strings); defaults to an empty list.
    * ``--quant_type`` / ``-quant_type``: one of ``bnb_8bit``, ``bnb_FP4``
      or ``bnb_NF4``; defaults to ``bnb_8bit``.

    Args:
        parser: parser object exposing ``add_argument_group``; the returned
            group must expose an ``add`` method (presumably a
            ConfigArgParse-style parser -- confirm against the callers).
    """
    quant_group = parser.add_argument_group("Quant options")

    # Each entry is (option flags, keyword arguments); the table is rebuilt
    # on every call so each parser receives its own default list object.
    option_table = (
        (
            ("--quant_layers", "-quant_layers"),
            {
                "default": [],
                "nargs": "+",
                "type": str,
                "help": "list of layers to be compressed in 4/8bit.",
            },
        ),
        (
            ("--quant_type", "-quant_type"),
            {
                "default": "bnb_8bit",
                "choices": ["bnb_8bit", "bnb_FP4", "bnb_NF4"],
                "type": str,
                "help": "Type of compression.",
            },
        ),
    )

    # Registration order matches the table order: quant_layers, quant_type.
    for flags, settings in option_table:
        quant_group.add(*flags, **settings)


def train_opts(parser):
"""All options used in train."""
# options related to data preparation
Expand All @@ -1541,6 +1544,7 @@ def train_opts(parser):
model_opts(parser)
_add_train_general_opts(parser)
_add_train_dynamic_data(parser)
_add_quant_opts(parser)


def _add_decoding_opts(parser):
Expand Down Expand Up @@ -1828,6 +1832,8 @@ def translate_opts(parser, dynamic=False):
# Adding options related to Transforms
_add_dynamic_transform_opts(parser)

_add_quant_opts(parser)


# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
Expand Down

0 comments on commit f13dbd8

Please sign in to comment.