Skip to content

Commit 8cc20fb

Browse files
authored
Merge pull request #509 from SMAntony/main
Added quantization support for huggingface models
2 parents dce3186 + 3cd833f commit 8cc20fb

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/vanna/hf/hf.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
class Hf(VannaBase):
88
def __init__(self, config=None):
9-
model_name = self.config.get(
10-
"model_name", None
11-
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct
12-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9+
model_name_or_path = self.config.get(
10+
"model_name_or_path", None
11+
) # e.g. meta-llama/Meta-Llama-3-8B-Instruct or local path to the model checkpoint files
12+
# list of quantization methods supported by transformers package: https://huggingface.co/docs/transformers/main/en/quantization/overview
13+
quantization_config = self.config.get("quantization_config", None)
14+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
1315
self.model = AutoModelForCausalLM.from_pretrained(
14-
model_name,
15-
torch_dtype="auto",
16+
model_name_or_path,
17+
quantization_config=quantization_config,
1618
device_map="auto",
1719
)
1820

0 commit comments

Comments
 (0)