From eabc3c2df0fb50122709190d8279f90d8d1c50b0 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 6 Sep 2024 16:39:49 +0000 Subject: [PATCH] lint --- jetstream_pt/cli.py | 3 +++ jetstream_pt/fetch_models.py | 2 +- jetstream_pt/third_party/llama/model_exportable.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py index fb1eb78..ce49d55 100644 --- a/jetstream_pt/cli.py +++ b/jetstream_pt/cli.py @@ -208,6 +208,8 @@ def interactive(): def main(): + """Main function.""" + def main_real(argv): """Entry point""" if len(argv) < 2: @@ -223,6 +225,7 @@ def main_real(argv): print( "Invalid arguments. please specify 'list', 'serve', or 'interactive'." ) + app.run(main_real) return 0 diff --git a/jetstream_pt/fetch_models.py b/jetstream_pt/fetch_models.py index a864442..c3e2312 100644 --- a/jetstream_pt/fetch_models.py +++ b/jetstream_pt/fetch_models.py @@ -40,7 +40,7 @@ class ModelInfo: num_layers: int # number of kv heads num_kv_heads: int - + head_dim: int n_reps: int # repeatition for GQA diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py index 368848b..7cebeb5 100644 --- a/jetstream_pt/third_party/llama/model_exportable.py +++ b/jetstream_pt/third_party/llama/model_exportable.py @@ -385,5 +385,5 @@ def transform(val, n_heads): value, self.params.n_kv_heads or self.params.n_heads ) res = super().convert_hf_weights(updated) - res['freqs_cis'] = self.freqs_cis + res["freqs_cis"] = self.freqs_cis return res