Change mixtral moe key name for trt-llm (#9620)
* fix minor import bug

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* change moe key values

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* add weight to the key

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

---------

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
oyilmaz-nvidia authored Jul 5, 2024
1 parent 18ecd41 commit b52229f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/export/trt_llm/converter/utils.py
@@ -439,14 +439,14 @@ def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals, storage_t
         split_w3s = np.split(w3, split_factor, axis=1)

         split_vals = [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)]
-        key = f'{layer_prefix}.mlp.experts_weight_1'
+        key = f'{layer_prefix}.mlp.fc.weight'
         save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)

     elif "experts.linear_fc2.weight" in key:
         cat_dim = -1
         val = np.concatenate(vals, axis=cat_dim)
         split_vals = np.split(val, split_factor, axis=cat_dim)
-        key = f'{layer_prefix}.mlp.experts_weight_2'
+        key = f'{layer_prefix}.mlp.proj.weight'
         save_expert_split(split_vals, saved_dir, key, tp_rank, split_factor)
     else:
         print(f"[WARNING] {key} not handled by converter")
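
To make the effect of the hunk easier to follow, here is a minimal, self-contained sketch of the two splitting steps and the renamed keys. It is not the NeMo implementation. The helper names `split_fc` and `split_proj`, the toy shapes, and the `layer_prefix` value are assumptions chosen for illustration. Only the NumPy calls and the two key strings are taken from the diff.

```python
import numpy as np


def split_fc(w1, w3, split_factor):
    """Split w1 and w3 along axis 1 and pair the shards per rank as [w3 | w1]."""
    split_w1s = np.split(w1, split_factor, axis=1)
    split_w3s = np.split(w3, split_factor, axis=1)
    # Same pairing order as the diff: the w3 shard first, then the w1 shard.
    return [np.concatenate(item, axis=1) for item in zip(split_w3s, split_w1s)]


def split_proj(vals, split_factor):
    """Concatenate the incoming pieces on the last axis, then split them again."""
    cat_dim = -1
    val = np.concatenate(vals, axis=cat_dim)
    return np.split(val, split_factor, axis=cat_dim)


# Hypothetical prefix and toy shapes (num_experts=2, ffn=4, hidden=3).
layer_prefix = "transformer.layers.0"
w1 = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
w3 = -w1

fc_shards = split_fc(w1, w3, split_factor=2)
proj_shards = split_proj([np.ones((2, 3, 4), dtype=np.float32)], split_factor=2)

# Key names after this commit (previously experts_weight_1 / experts_weight_2).
fc_key = f"{layer_prefix}.mlp.fc.weight"
proj_key = f"{layer_prefix}.mlp.proj.weight"

print(fc_key, [s.shape for s in fc_shards])
print(proj_key, [s.shape for s in proj_shards])
```

Each entry in `fc_shards` and `proj_shards` corresponds to one tensor-parallel rank. The commit only changes the key under which those shards are saved, not how they are computed.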