
Changes to support TNLRV3 fine-tuning #4639

Merged 4 commits from tix/turing_finetuning into master on Jul 30, 2020

Conversation

@Tixxx Tixxx (Contributor) commented Jul 28, 2020

Description
Changes to support the TNLRv3 fine-tuning task:

  1. Added a gradient op for ReduceLogSumExp; the reference implementation was adapted from PyTorch (a sketch of the gradient math follows this list).
  2. Fixed a bug when passing fp16 inputs to the CudnnReduce kernel: the runtime compute type should be float.
  3. Added sanitization code in the Python frontend to remove redundant states so the saved state matches the PyTorch state dict.
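
For reference, the gradient this op implements follows from y = logsumexp(x): dL/dx = dL/dy * exp(x - y), with y and dL/dy broadcast back over the reduced dimension. The snippet below is a minimal PyTorch sketch of that formula (not the actual ORT kernel code), checked against autograd:

```python
import torch

def reduce_logsumexp_grad(grad_y, x, y, dim, keepdim=False):
    # dL/dx = dL/dy * exp(x - y), broadcasting y and grad_y over the reduced dim.
    if not keepdim:
        grad_y = grad_y.unsqueeze(dim)
        y = y.unsqueeze(dim)
    return grad_y * (x - y).exp()

# Quick check against PyTorch autograd.
x = torch.randn(3, 4, requires_grad=True)
y = torch.logsumexp(x, dim=1)
y.sum().backward()
manual = reduce_logsumexp_grad(torch.ones_like(y), x.detach(), y.detach(), dim=1)
assert torch.allclose(x.grad, manual)
```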

Motivation and Context
To support TNLRv3 fine-tuning.

Added test
Fixed type mismatch when calling the CudnnReduce kernel
Fixed the Python frontend to remove redundant states to match the PyTorch state dict (a sketch of the idea follows below)
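
As a rough illustration of that frontend sanitization, here is a sketch under assumptions; the marker names are hypothetical examples of training-only entries, and this is not the actual ORT frontend code:

```python
def sanitize_state_dict(state_dict, redundant_markers=("_fp16", "Moment_1", "Moment_2")):
    """Illustrative sketch: drop training-only entries (hypothetical marker names)
    so the remaining keys line up with the PyTorch model's state_dict."""
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if not any(marker in name for marker in redundant_markers)
    }
```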
@Tixxx Tixxx requested a review from a team July 28, 2020 05:32
@Tixxx Tixxx added the training and component:training-frontend labels Jul 28, 2020
@Tixxx Tixxx requested a review from SherlockNoMad July 29, 2020 05:01
@SherlockNoMad SherlockNoMad (Contributor) left a comment

PR mostly looks good. Please address the comments and I think it's ready to go.

@Tixxx Tixxx requested a review from SherlockNoMad July 29, 2020 21:23
@Tixxx Tixxx merged commit f90a2d4 into master Jul 30, 2020
@Tixxx Tixxx deleted the tix/turing_finetuning branch July 30, 2020 02:18
thiagocrepaldi pushed a commit that referenced this pull request Aug 31, 2020
#4639 changed the default behavior by removing optimizer state from the
state_dict/checkpoint APIs. The reason for that change was to allow models
trained with ORT to be used for inference in PyTorch, which is an important
feature.

Due to the aforementioned change, when resuming training from a checkpoint,
the optimizer would start with random weights, leading to poor performance.
This behavior would also cause reproducibility issues, as the optimizer
wouldn't be able to resume from its previous state.

This PR adds a boolean flag to the state_dict/save_checkpoint APIs that,
when True (the default), saves both model and optimizer state; when False,
only the model state is kept (a conceptual sketch follows below).
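
A conceptual sketch of the flag's effect; the function and parameter names below are hypothetical and not the actual ORT training API:

```python
# Conceptual sketch only -- function and flag names are hypothetical,
# not the actual ORT training frontend API added by this commit.
def build_state_dict(model_state, optimizer_state, include_optimizer_state=True):
    """Return the state to persist.

    include_optimizer_state=True  -> model weights + optimizer state (resume training)
    include_optimizer_state=False -> model weights only (export for PyTorch inference)
    """
    state = dict(model_state)
    if include_optimizer_state:
        state.update(optimizer_state)
    return state
```

With the flag left at its default of True, resuming from a checkpoint restores the optimizer's previous state; passing False reproduces the model-only behavior originally introduced by this PR for PyTorch inference.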
thiagocrepaldi pushed a commit that referenced this pull request Sep 1, 2020 (same commit message as above)