feat: add LoRA adapter layer and Mixtral LoRA training #53
base: rocm_dev
Conversation
Branch updated: 4b5b3fc → 31a02d4
Branch updated: 31a02d4 → e86131b
cc: @wenchenvincent
```diff
@@ -137,6 +137,7 @@ def main():
     args = parser.parse_args()

+    mp.set_start_method('spawn')
```
Why do we need this line here?
convert.py has flaky behaviour in different environments, and it is not 100% clear whether the issue is:
- in the script itself,
- a Python/PyTorch behaviour change since the script was created,
- or environment-specific configuration.

Sometimes the start method is fork, and that causes convert.py to crash. To be on the safe side, I explicitly set the process start method to spawn.
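For reference, a minimal sketch of a defensive way to force the start method (assuming mp is torch.multiprocessing or the stdlib multiprocessing; the get_start_method guard and force=True are illustrative additions, not what the PR does):

```python
import torch.multiprocessing as mp

def main():
    # ... argument parsing elided ...
    # Force 'spawn' regardless of the environment default; force=True avoids a
    # RuntimeError if a start method has already been set elsewhere.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    # ... rest of the conversion pipeline ...

if __name__ == '__main__':
    main()
```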
Does it use distributed checkpointing? I recall there was a similar issue with distributed checkpointing: #47
@zstreet87 Could you take a look at this change?
convert.py knows nothing about the checkpoint format - that is the saver's responsibility. The torch format is used when the saver is saver_mcore.py, and I'm using saver_mcore.py. The issue with the fork start method appears for the torch checkpoint format, but I think it is unrelated to the format.
"skip_bias_add": True, | ||
} | ||
COLUMN_PARALLEL_LAYERS = [ | ||
partial(TELinear, **LORA_LAYERS_DEFAULT_CONFIG, init_method=KAIMING_INIT_METHOD, parallel_mode=None, skip_weight_param_allocation=False), |
Do you mean Linear instead of TELinear here?
te.pytorch.Linear, and especially torch.nn.Linear, have quite different constructor signatures, whereas the TELinear constructor is aligned with ColumnParallelLinear, TEColumnParallelLinear, etc. and encapsulates those differences. To make the code more readable, I explicitly used TELinear. But essentially, in this case, it is a thin wrapper around torch.nn.Linear.
TELinear is not a wrapper around torch.nn.Linear but around te.pytorch.Linear.

In Megatron-LM, there are two alternative transformer implementations: local (using PyTorch layers) and transformer-engine (using TE layers). ColumnParallelLinear uses PyTorch Linear layers and TEColumnParallelLinear uses TE Linear layers. Usually, when a model is constructed with ColumnParallelLinear, it means that TE is not available, so it is not appropriate to use TELinear here.

Given that we cannot use torch.nn.Linear directly, it seems that we will also need to create a thin wrapper around torch.nn.Linear (see the sketch below). And this also triggers another question from me: can we reuse the ColumnParallelLinear wrapper for the second LoRA layer here?
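For illustration, a minimal sketch of the kind of thin wrapper being suggested. The class name and keyword set are hypothetical (loosely modeled on the Megatron-style constructors discussed above), not the PR's actual code:

```python
import torch
import torch.nn.functional as F

class LocalLoraLinear(torch.nn.Linear):
    """Hypothetical thin wrapper around torch.nn.Linear whose constructor accepts
    (a subset of) the Megatron-style keyword arguments, so it could be used in the
    'local' (non-TE) code path. Keywords that are meaningless for a plain linear
    layer are accepted and ignored."""

    def __init__(self, input_size, output_size, *, config=None, init_method=None,
                 bias=True, skip_bias_add=False, **unused_parallel_kwargs):
        super().__init__(input_size, output_size, bias=bias)
        self.skip_bias_add = skip_bias_add
        if init_method is not None:
            init_method(self.weight)

    def forward(self, x):
        # Megatron linear layers return an (output, bias) tuple; mimic that contract.
        if self.skip_bias_add:
            return F.linear(x, self.weight), self.bias
        return super().forward(x), None
```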
Hmm, I actually have a further question about whether this would work for TP or not.

It seems that for a base layer like TEColumnParallelLinear, we use two LoRA layers: the first layer is TELinear and the second layer is TEColumnParallelLinear. Does that mean the first layer will not be sliced across the different GPUs?
Yes, indeed, I made a typo in the last sentence: TELinear wraps te.pytorch.Linear.

I implemented the Linear layer as a wrapper around ColumnParallelLinear. The main difference is that the weight output size must be non-sharded. To achieve this, I copied some code from the ColumnParallelLinear constructor.

Yes, your understanding is correct: all Linear/TELinear layers in the LoraAdapter are not sliced. This is a deliberate decision (see the sketch after this list):
- For TP, we sacrifice some memory to gain performance. Using a different approach would introduce approximately five additional inter-GPU calls per LoraAdapter.
- For EP+PP, which, as we observed, is the most performant training configuration for MoE models, no layers in the model are sliced.
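To make the wiring concrete, here is a rough, non-parallel sketch of how a LoRA adapter of this shape composes with a frozen base layer. Names and arguments are illustrative only; the PR's actual LoraAdapter and its choice of parallel layers differ:

```python
import torch

class LoraAdapterSketch(torch.nn.Module):
    """Illustrative LoRA wrapper: the base layer stays frozen; lora_a (the
    unsliced 'first' layer) projects hidden -> rank, and lora_b (which may be a
    column-parallel layer in the real PR) projects rank -> output."""

    def __init__(self, base_layer, in_features, out_features, rank=16, alpha=32):
        super().__init__()
        self.base_layer = base_layer
        for p in self.base_layer.parameters():
            p.requires_grad = False          # only the adapter is trained
        self.lora_a = torch.nn.Linear(in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, out_features, bias=False)
        torch.nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x):
        base_out = self.base_layer(x)
        delta = self.scaling * self.lora_b(self.lora_a(x))
        if isinstance(base_out, tuple):      # Megatron layers return (out, bias)
            out, bias = base_out
            return out + delta, bias
        return base_out + delta
```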
Thanks for the clarification! While I think it might be okay to sacrifice some memory to gain performance, I am somewhat concerned about the functionality.

So in the case of TP, the weights of the first LoRA layer are not sliced, while the weights of the second LoRA layer are sliced across the GPUs within the same TP group, and the input data is the same across the GPUs within that TP group. When we pass the activation of the first LoRA layer to the second LoRA layer, how do we make sure it is sliced properly? And how do we make sure that gradient reduction and accumulation are done properly in the backward pass? In the scheme of this PR, we are effectively doing DP for the first LoRA layer and TP for the second LoRA layer within a TP group. The combination of the two might be error prone, so we will need tests to make sure this is implemented correctly.
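For concreteness, a single-process sketch of the forward-pass math behind this concern: with a replicated input and a column-parallel second layer, each TP rank computes a slice of the output, and concatenating the slices matches the unsliced result. This says nothing about the backward-pass reductions, which is exactly what tests would need to cover (tensor sizes below are arbitrary):

```python
import torch

torch.manual_seed(0)
rank_dim, out_features, tp = 16, 64, 2

x = torch.randn(8, rank_dim)                  # activation of the (replicated) first LoRA layer
w_full = torch.randn(out_features, rank_dim)  # weight of the second LoRA layer

# Column-parallel split: each TP rank owns a slice of the output dimension.
w_shards = torch.chunk(w_full, tp, dim=0)

# Every rank sees the same replicated input and computes its output slice;
# in real TP the slices are gathered (or kept sharded) instead of concatenated here.
y_sharded = torch.cat([x @ w.t() for w in w_shards], dim=-1)
y_full = x @ w_full.t()
assert torch.allclose(y_sharded, y_full, atol=1e-6)
```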
@wenchenvincent: I see that this PR didn't go through CI. Do you have any idea why?
The PR adds:
- a LoraAdapter class to enable LoRA for models

The focus is mostly on expert parallelism.

Out of scope for this PR (will be added later):
- VocabParallelEmbedding and TopKRouter wrapped by LoraAdapter