LoRA for MoE Layer #9396

Merged · 5 commits · Jun 11, 2024
Changes from 4 commits
@@ -14,19 +14,16 @@

import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.custom_layers.transformer_engine import (
SplitAlongDim,
TEColumnParallelLinear,
TELayerNormColumnParallelLinear,
)
from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.moe.experts import SequentialMLP
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import make_viewless_tensor

@@ -37,6 +34,8 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraMoe4HtoHAdapterConfig,
LoraMoeHto4HAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
@@ -281,13 +280,15 @@ def forward(
class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo IA3 adapter to this MCore layer.
Setup NeMo IA3 and LoRA adapters for this MCore layer.
"""
self.set_accepted_adapter_types(
[
LoraUnfusedHto4HAdapterConfig._target_,
LoraHto4HAdapterConfig._target_,
Lora4HtoHAdapterConfig._target_,
LoraMoeHto4HAdapterConfig._target_,
LoraMoe4HtoHAdapterConfig._target_,
MLPInfusedAdapterConfig._target_,
]
) # only self attn (packed qkv) for now
@@ -302,9 +303,12 @@ def mcore_register_adapters(self):
# overlap is used.
self.linear_fc1.return_layernorm_output_gathered = True

def forward(self, hidden_states):
def forward(self, hidden_states, expert_idx=None):
# [s, b, 4 * h/p]
if self.linear_fc1.te_return_bias:
if isinstance(self.linear_fc1, ColumnParallelLinear):
layernorm_output = hidden_states
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
elif self.linear_fc1.te_return_bias:
intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states)
else:
# bias_parallel is None
@@ -315,15 +319,19 @@ def forward(self, hidden_states):
lora_adapter = None
lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER)
lora_moe_fc1_adapter = self.get_adapter_module(AdapterName.LORA_MOE_Hto4H_ADAPTER)
if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_adapter = lora_fc1_adapter
if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER"
lora_adapter = lora_unfused_fc1_adapter

lora_output = 0
if lora_adapter:
lora_output = lora_adapter(layernorm_output)
intermediate_parallel = intermediate_parallel + lora_output
elif lora_moe_fc1_adapter and self.adapter_cfg[AdapterName.LORA_MOE_Hto4H_ADAPTER]['enabled']:
lora_output = lora_moe_fc1_adapter(layernorm_output, expert_idx)
intermediate_parallel = intermediate_parallel + lora_output

if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
@@ -363,14 +371,51 @@ def glu(x):

# LoRA logic
if self.is_adapter_available():
lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
if lora_linear_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
lora_output = lora_linear_fc2_adapter(intermediate_parallel)
output = output + lora_output
lora_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
lora_moe_fc2_adapter = self.get_adapter_module(AdapterName.LORA_MOE_4HtoH_ADAPTER)

lora_output = 0
if lora_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
lora_output = lora_fc2_adapter(intermediate_parallel)
elif lora_moe_fc2_adapter and self.adapter_cfg[AdapterName.LORA_MOE_4HtoH_ADAPTER]['enabled']:
lora_output = lora_moe_fc2_adapter(intermediate_parallel, expert_idx)

output = output + lora_output

return output, output_bias
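
As a rough mental model of the adapter arithmetic above: for tokens routed to expert *e*, the output is the frozen expert MLP plus a low-rank correction owned by that expert. The sketch below uses plain `torch.nn` layers with made-up names (`TinyLoRA`, `ExpertMLPWithLoRA`) rather than NeMo's `ParallelLinearAdapter`, so treat it as an illustration of the math, not the actual implementation.

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyLoRA(nn.Module):
    """Low-rank correction B(A(x)); B starts at zero so training begins from the base weights."""
    def __init__(self, d_in: int, d_out: int, rank: int):
        super().__init__()
        self.A = nn.Linear(d_in, rank, bias=False)
        self.B = nn.Linear(rank, d_out, bias=False)
        nn.init.zeros_(self.B.weight)  # analogous to row_init_method='zero'

    def forward(self, x):
        return self.B(self.A(x))

class ExpertMLPWithLoRA(nn.Module):
    """One expert MLP plus its own fc1/fc2 LoRA pair (the hto4h and 4htoh corrections)."""
    def __init__(self, hidden: int, ffn: int, rank: int):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(hidden, ffn), nn.Linear(ffn, hidden)
        self.lora_fc1, self.lora_fc2 = TinyLoRA(hidden, ffn, rank), TinyLoRA(ffn, hidden, rank)

    def forward(self, x):
        # mirrors `intermediate_parallel + lora_output` and `output + lora_output` above
        inter = F.gelu(self.fc1(x) + self.lora_fc1(x))
        return self.fc2(inter) + self.lora_fc2(inter)
```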


class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
We don't want the SequentialMLP layer to take any adapters; we only want to override its forward() behavior.
"""
pass

def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)

cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert a zero at the beginning for convenient offset indexing
zero_tensor = torch.zeros(1, dtype=torch.long)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden, expert_num) # expert: MLP

output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias

return output_local, output_bias_local
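
The expert loop above carves the routed-token buffer with a prefix sum of `tokens_per_expert`; a quick standalone illustration with made-up counts shows how the offsets map tokens to experts:

```python
import torch

tokens_per_expert = torch.tensor([3, 1, 2])   # 6 routed tokens across 3 local experts (example values)
cumsum = torch.cumsum(tokens_per_expert, dim=0)                   # tensor([3, 4, 6])
offsets = torch.cat((torch.zeros(1, dtype=torch.long), cumsum))   # tensor([0, 3, 4, 6])

hidden = torch.arange(6.0).unsqueeze(-1)      # stand-in for permuted_local_hidden_states, shape [6, 1]
for expert_idx in range(len(tokens_per_expert)):
    start, end = offsets[expert_idx], offsets[expert_idx + 1]
    print(expert_idx, hidden[start:end].squeeze(-1).tolist())
# 0 [0.0, 1.0, 2.0]
# 1 [3.0]
# 2 [4.0, 5.0]
```

Passing `expert_num` into each expert's `forward` is what lets the per-expert LoRA bank pick the matching adapter.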


class MCoreGPTEmbeddingMixin(LanguageModelEmbedding, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
@@ -83,6 +83,8 @@ class AdapterName(str, enum.Enum):
LORA_Hto4H_ADAPTER = "lora_hto4h_adapter"
LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter"
LORA_4HtoH_ADAPTER = "lora_4htoh_adapter"
LORA_MOE_Hto4H_ADAPTER = "lora_moe_hto4h_adapter"
LORA_MOE_4HtoH_ADAPTER = "lora_moe_4htoh_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"

@@ -611,6 +613,80 @@ class LoraUnfusedKQVAdapterConfig(AdapterConfig):
_target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__)


class LoraMoeAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
num_moe_experts: int,
in_features: int,
out_features: int,
dim: int,
activation: str = 'identity',
norm_position: Optional[str] = None,
norm_type: Optional[str] = None,
column_init_method: str = 'xavier',
row_init_method: str = 'zero',
gather_output: bool = False,
input_is_parallel: bool = False,
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False,
**kwargs,
):
super().__init__()

self.num_moe_experts = num_moe_experts
adapter_args = {
"in_features": in_features,
"out_features": out_features,
"dim": dim,
"activation": activation,
"norm_position": norm_position,
"norm_type": norm_type,
"column_init_method": column_init_method,
"row_init_method": row_init_method,
"gather_output": gather_output,
"input_is_parallel": input_is_parallel,
"dropout": dropout,
"model_parallel_config": model_parallel_config,
"alpha": alpha,
"dropout_position": dropout_position,
"a2a_experimental": a2a_experimental,
}
self.expert_adapters = nn.ModuleList()
for i in range(num_moe_experts):
self.expert_adapters.append(ParallelLinearAdapter(**adapter_args))

def forward(self, x, expert_idx):
return self.expert_adapters[expert_idx](x)
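
`LoraMoeAdapter` is essentially a bank of independent adapters, one per expert, dispatched by `expert_idx`. A self-contained sketch of that pattern with plain linear layers (illustrative class and variable names, not NeMo API):

```python
import torch
import torch.nn as nn

class LoRABank(nn.Module):
    """One low-rank A/B pair per expert; forward() routes to the adapter of the given expert."""
    def __init__(self, num_experts: int, d_in: int, d_out: int, rank: int):
        super().__init__()
        self.adapters = nn.ModuleList(
            nn.Sequential(nn.Linear(d_in, rank, bias=False), nn.Linear(rank, d_out, bias=False))
            for _ in range(num_experts)
        )
        for adapter in self.adapters:
            nn.init.zeros_(adapter[1].weight)   # zero-init B, so the bank is a no-op at step 0

    def forward(self, x, expert_idx: int):
        return self.adapters[expert_idx](x)

bank = LoRABank(num_experts=4, d_in=1024, d_out=4096, rank=16)
tokens_for_expert_2 = torch.randn(5, 1024)
print(bank(tokens_for_expert_2, expert_idx=2).shape)   # torch.Size([5, 4096])
```

Because each token chunk reaches exactly one expert, only that expert's A/B pair receives gradients for it.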


@dataclass
class LoraMoeHto4HAdapterConfig(AdapterConfig):
num_moe_experts: int
in_features: int
out_features: int
dim: int
activation: str = 'identity'
norm_position: Optional[str] = None
norm_type: Optional[str] = None
column_init_method: str = 'xavier'
row_init_method: str = 'zero'
gather_output: bool = False
input_is_parallel: bool = False
dropout: float = 0.0
dropout_position: str = 'post'
alpha: float | None = None
a2a_experimental: bool = False
_target_: str = "{0}.{1}".format(LoraMoeAdapter.__module__, LoraMoeAdapter.__name__)


@dataclass
class LoraMoe4HtoHAdapterConfig(LoraMoeHto4HAdapterConfig):
input_is_parallel: bool = True
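
A practical note on the two configs above: because every expert owns its own A/B pair, trainable parameters grow linearly with `num_moe_experts`. A back-of-the-envelope count, using illustrative Mixtral-like sizes rather than anything taken from this PR (and ignoring the doubled fc1 width used with gated activations):

```python
# Illustrative dimensions only; real values come from the model config.
hidden, ffn, rank, num_experts = 4096, 14336, 32, 8

hto4h_per_expert = rank * hidden + ffn * rank     # A: rank x hidden, B: ffn x rank   -> 589_824
fourhtoh_per_expert = rank * ffn + hidden * rank  # A: rank x ffn,   B: hidden x rank -> 589_824

per_moe_layer = num_experts * (hto4h_per_expert + fourhtoh_per_expert)
print(f"{per_moe_layer:,}")   # 9,437,184 trainable adapter params per MoE layer
```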


class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""
The Tensor Parallel MLP prompt encoder network that is used to generate the virtual
@@ -690,20 +766,14 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
self.is_inference_ready = True
return True

def clear_inference_table(
self,
):
def clear_inference_table(self):
self.inference_table.fill_(0.0)
self.is_inference_ready = False

def get_inference_table(
self,
):
def get_inference_table(self):
return self.inference_table.data

def inner_forward(
self,
):
def inner_forward(self):
input_embeds = self.embedding(self.indices).unsqueeze(0)
intermediate_parallel, bias_parallel = self.first(input_embeds)
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
40 changes: 35 additions & 5 deletions nemo/collections/nlp/parts/peft_config.py
@@ -23,6 +23,7 @@
MCoreGPTEmbeddingMixin,
MCoreMLPMixin,
MCoreSelfAttentionMixin,
MCoreSequentialMLPMixin,
MCoreTransformerLayerMixin,
)
except (ImportError, ModuleNotFoundError):
@@ -36,6 +37,8 @@
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
LoraMoe4HtoHAdapterConfig,
LoraMoeHto4HAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
@@ -169,7 +172,10 @@ def __init__(self, cfg):

elif module == PEFT_MODULE_MAP["hto4h_module"]:
hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size
if lora_cfg.get("variant", "nemo") == "canonical":
if cfg.get('num_moe_experts', None):
_adapter_name = AdapterName.LORA_MOE_Hto4H_ADAPTER
_adapter_cfg_cls = LoraMoeHto4HAdapterConfig
elif lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER
_adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig
else:
@@ -180,13 +186,35 @@ def _create_lora_config(
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
if _adapter_name == AdapterName.LORA_MOE_Hto4H_ADAPTER:
name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)]
for i in range(int(cfg.num_moe_experts)):
name_key_to_mcore_mixins[_adapter_name].append(
(f"mlp.experts.local_experts.{i}", MCoreMLPMixin)
)
else:
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]

elif module == PEFT_MODULE_MAP["4htoh_module"]:
if cfg.get('num_moe_experts', None):
_adapter_name = AdapterName.LORA_MOE_4HtoH_ADAPTER
_adapter_cfg_cls = LoraMoe4HtoHAdapterConfig
else:
_adapter_name = AdapterName.LORA_4HtoH_ADAPTER
_adapter_cfg_cls = Lora4HtoHAdapterConfig

adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, _adapter_cfg_cls
)
name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)]
name_key_to_cfg[_adapter_name] = adapter_cfg
if _adapter_name == AdapterName.LORA_MOE_4HtoH_ADAPTER:
name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)]
for i in range(int(cfg.num_moe_experts)):
name_key_to_mcore_mixins[_adapter_name].append(
(f"mlp.experts.local_experts.{i}", MCoreMLPMixin)
)
else:
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
else:
logging.error(
f"Unrecognized target_module string: {module}.\n"
@@ -221,6 +249,8 @@ def _create_lora_config(
assert kv_channels is not None, "kv_channels must be provided for canonical Lora"
config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
config_args.pop("out_features")
elif adapter_cfg_cls in (LoraMoeHto4HAdapterConfig, LoraMoe4HtoHAdapterConfig):
config_args.update({'num_moe_experts': cfg.num_moe_experts})

if lora_cfg.weight_tying:
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)