Skip to content

Commit

Permalink
fix: pass rope_theta argument when initializing LlamaLikeBlock for mo…
Browse files Browse the repository at this point in the history
…dels like qwen2, mistral, etc. (#568)

Co-authored-by: Shuai Xie <shuaixie@jd.com>
  • Loading branch information
Shuai-Xie and Shuai Xie authored Aug 4, 2024
1 parent 2e2635b commit 202b967
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions awq/models/aquila.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down
1 change: 1 addition & 0 deletions awq/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down
1 change: 1 addition & 0 deletions awq/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down
1 change: 1 addition & 0 deletions awq/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down
1 change: 1 addition & 0 deletions awq/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down
1 change: 1 addition & 0 deletions awq/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def fuse_transformer(self):
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)

Expand Down

0 comments on commit 202b967

Please sign in to comment.