Fix weights not updating during training (#6068)
Fix weights not updating during training (#6039)

Summary:
Loaded weights were not updating in ET because the returned weights were clones of the originals, so in-place gradient updates never reached the module's own weight storage.

Pull Request resolved: #6039
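The failure mode is easiest to see in miniature. The sketch below is a hypothetical, self-contained illustration: plain Python lists stand in for the module's weight tensors, and the `forward`/`clone_outputs` names only mimic the pybindings behavior described above, not the real ExecuTorch API.

```python
weights = [1.0, 2.0, 3.0]  # the module's "own" storage

def forward(clone_outputs=True):
    # The runtime can hand back either a copy or the original buffer.
    return list(weights) if clone_outputs else weights

# With clone_outputs=True, the in-place "gradient step" edits only the copy:
out = forward(clone_outputs=True)
for i in range(len(out)):
    out[i] -= 0.1
after_clone = list(weights)  # still [1.0, 2.0, 3.0]: training stalls

# With clone_outputs=False, the same step mutates the real weights:
out = forward(clone_outputs=False)
for i in range(len(out)):
    out[i] -= 0.1
# weights is now approximately [0.9, 1.9, 2.9]
```

The first loop is the pre-fix behavior: every update lands on a throwaway clone, so eval loss never moves.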

Test Plan:
```
> buck2 run fbcode//executorch/examples/llm_pte_finetuning:runner -- --cfg=fbcode/executorch/examples/llm_pte_finetuning/phi3_config.yaml --model_file=phi3_mini_lora.pte
Evaluating the model before training...
100%|██████████████████████████████████████████████████████████████████████████████████████| 3/3 [31:23<00:00, 627.98s/it]
Eval loss:  tensor(2.3778)
grad_start:`1`
param_start:`129`
100%|██████████████████████████████████████████████████████████████████████████████████████| 5/5 [52:29<00:00, 629.84s/it]
Losses:  [2.7152762413024902, 0.7890686988830566, 2.249271869659424, 1.4777560234069824, 0.8378427624702454]
100%|██████████████████████████████████████████████████████████████████████████████████████| 3/3 [30:35<00:00, 611.90s/it]
Eval loss:  tensor(0.8464)
```

Reviewed By: dpalmasan

Differential Revision: D64084552

Pulled By: dvorjackz

fbshipit-source-id: 9d478dda02f7bcaa5964d83d257d0db5bfe9feab
(cherry picked from commit 867c96a)

Co-authored-by: Jack Zhang <dvorjackz@gmail.com>
pytorchbot and dvorjackz authored Oct 9, 2024
1 parent c1b10a7 commit 4f4dc0b
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/llm_pte_finetuning/runner.py
```diff
@@ -94,7 +94,11 @@ def main() -> None:
         else:
             labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

-        out = et_mod.forward((tokens, labels))
+        # Do not clone outputs, since we want the original weights to be returned
+        # for us to update with the gradients in-place.
+        # See https://github.com/pytorch/executorch/blob/main/extension/pybindings/pybindings.cpp#L736
+        # for more info.
+        out = et_mod.forward((tokens, labels), clone_outputs=False)  # pyre-ignore

         loss = out[0]
         losses.append(loss.item())
```
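With `clone_outputs=False`, the forward call hands back the original parameter tensors, so the training step can update them in place. The sketch below is illustrative only: the index layout is inferred from the runner's log (`grad_start:1`, `param_start:129`), plain Python lists stand in for tensors, and `apply_sgd_step` is not an ExecuTorch API.

```python
# Assumed output layout, inferred from the log above:
#   out[0]                      -> loss
#   out[grad_start:param_start] -> gradients
#   out[param_start:]           -> parameters (originals when clone_outputs=False)

def apply_sgd_step(out, grad_start, param_start, lr=0.1):
    """Subtract lr * grad from each parameter, mutating it in place."""
    grads = out[grad_start:param_start]
    params = out[param_start:]
    for p, g in zip(params, grads):
        for i in range(len(p)):
            # In-place write: this only reaches the module's weights because
            # the runtime returned the original buffers, not clones.
            p[i] -= lr * g[i]
```

Had the outputs been clones, `params` would be copies and the module's weights would never change, which is exactly the stalled-training symptom this commit fixes.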
