Add inference kv cache support for transformer TE path #6627

Merged
10 commits merged on Jun 6, 2023
23 changes: 22 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -154,6 +154,15 @@ def __getitem__(self, idx):
return self.sentences[idx]


def remove_padded_prompts(response, nb_paddings):
    if nb_paddings == 0:
        return response
    result = {}
    for k, v in response.items():
        if v is not None and isinstance(v, (list, torch.Tensor)):
            v = v[:-nb_paddings]
        result[k] = v
    return result


@hydra_runner(config_path="conf", config_name="megatron_gpt_inference")
def main(cfg) -> None:

Expand Down Expand Up @@ -254,22 +263,34 @@ def main(cfg) -> None:
"compute_logprob": cfg.inference.compute_logprob,
}

fp8_enabled = hasattr(model.cfg, "fp8") and bool(model.cfg.fp8)
if fp8_enabled:
nb_paddings = 0
while len(cfg.prompts) % 8 != 0:
cfg.prompts.append("")
nb_paddings += 1

# First method of running text generation, call model.generate method
response = model.generate(
inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params
)

if fp8_enabled:
response = remove_padded_prompts(response, nb_paddings)
print("***************************")
print(response)
print("***************************")

# Second method of running text generation, call trainer.predict [recommended]
bs = 8 if fp8_enabled else 2
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
request_dl = DataLoader(dataset=ds, batch_size=2)
request_dl = DataLoader(dataset=ds, batch_size=bs)
config = OmegaConf.to_container(cfg.inference)
model.set_inference_config(config)
response = trainer.predict(model, request_dl)

if fp8_enabled:
response[-1] = remove_padded_prompts(response[-1], nb_paddings)
print("***************************")
print(response)
print("***************************")
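For reference, the FP8 branch above rounds the prompt batch up to a multiple of 8 (presumably because the FP8 kernels expect batch sizes in multiples of 8) and then trims the extra entries from the response. Below is a minimal sketch of that round-trip; the helpers `pad_prompts_to_multiple` and `strip_padding` are illustrative stand-ins, not part of the PR:

```python
import torch

def pad_prompts_to_multiple(prompts, multiple=8, pad_value=""):
    """Pad the prompt list with empty strings until its length is a multiple of `multiple`."""
    nb_paddings = (-len(prompts)) % multiple
    return prompts + [pad_value] * nb_paddings, nb_paddings

def strip_padding(response, nb_paddings):
    """Drop the trailing entries that correspond to padded prompts, mirroring remove_padded_prompts."""
    if nb_paddings == 0:
        return response
    return {
        k: (v[:-nb_paddings] if isinstance(v, (list, torch.Tensor)) else v)
        for k, v in response.items()
    }

# Usage: 5 prompts are padded to 8, run through the model, then trimmed back to 5.
prompts, nb_paddings = pad_prompts_to_multiple(["a", "b", "c", "d", "e"])
assert len(prompts) == 8 and nb_paddings == 3
fake_response = {"sentences": prompts, "tokens": None}
trimmed = strip_padding(fake_response, nb_paddings)
assert len(trimmed["sentences"]) == 5
```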
26 changes: 21 additions & 5 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -945,6 +945,9 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.multi_query_attention = multi_query_attention

self.inference_current_sequence_len = 0
self.inference_params = None

self.activations_checkpoint_method = activations_checkpoint_method
self.activations_checkpoint_num_layers = activations_checkpoint_num_layers
self.activations_checkpoint_granularity = activations_checkpoint_granularity
@@ -1451,6 +1454,20 @@ def forward(
if get_key_value:
presents = []

if self.transformer_engine:
# Pass key value information to TE through inference_params to pre-allocate memory
if set_inference_key_value_memory:
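# Lightweight ad-hoc object carrying the attributes TE expects on inference_params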
self.inference_params = type('', (), {})()
self.inference_params.max_sequence_len = inference_max_sequence_len
self.inference_params.max_batch_size = hidden_states.size(1)
self.inference_params.batch_size_offset = 0
self.inference_params.key_value_memory_dict = {}
self.inference_params.sequence_len_offset = 0
self.inference_current_sequence_len = 0

if self.inference_params is not None:
self.inference_params.sequence_len_offset = self.inference_current_sequence_len

for index in range(self.num_layers):
layer = self._get_layer(index)
past = None
@@ -1479,19 +1496,15 @@ def forward(
checkpoint_core_attention = False

if self.transformer_engine:

inference_params = None

hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
inference_params=self.inference_params,
is_first_microbatch=self.is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
)

else:
hidden_states = layer(
hidden_states,
@@ -1507,6 +1520,9 @@ def forward(
cross_attention_relative_position_bias=cross_attention_relative_position_bias,
checkpoint_core_attention=checkpoint_core_attention,
)
# Update the current sequence length once, outside the layer loop
if self.transformer_engine:
self.inference_current_sequence_len += hidden_states.size(0)

# Skip counter update for eval and activation checkpointing
if torch.is_grad_enabled() and self.training:
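The TE path above follows the usual KV-cache bookkeeping: the first call with `set_inference_key_value_memory=True` builds a fresh `inference_params` object and resets the cached length, every subsequent call publishes the cached length as `sequence_len_offset`, and after the layer loop the cached length grows by the number of tokens just processed. The sketch below shows that pattern in isolation; names such as `TinyDecoder` and `make_inference_params` are toy stand-ins, not NeMo or TE APIs:

```python
from types import SimpleNamespace

def make_inference_params(max_sequence_len, max_batch_size):
    # Mirrors the attributes the PR sets up before handing inference_params to TE layers.
    return SimpleNamespace(
        max_sequence_len=max_sequence_len,
        max_batch_size=max_batch_size,
        batch_size_offset=0,
        sequence_len_offset=0,
        key_value_memory_dict={},  # per-layer KV tensors would be allocated lazily here
    )

class TinyDecoder:
    """Toy decoder that only tracks how many tokens have been cached so far."""

    def __init__(self):
        self.inference_params = None
        self.inference_current_sequence_len = 0

    def forward(self, num_new_tokens, set_inference_key_value_memory=False,
                inference_max_sequence_len=None, max_batch_size=1):
        if set_inference_key_value_memory:
            # Prefill call: allocate cache metadata and reset the running length.
            self.inference_params = make_inference_params(inference_max_sequence_len, max_batch_size)
            self.inference_current_sequence_len = 0
        if self.inference_params is not None:
            # Tell attention where the new tokens start inside the cache.
            self.inference_params.sequence_len_offset = self.inference_current_sequence_len
        # ... per-layer attention that reads/writes key_value_memory_dict would run here ...
        self.inference_current_sequence_len += num_new_tokens
        return self.inference_params.sequence_len_offset

# Prefill with a 16-token prompt, then decode two tokens one at a time.
dec = TinyDecoder()
assert dec.forward(16, set_inference_key_value_memory=True, inference_max_sequence_len=64) == 0
assert dec.forward(1) == 16   # the first generated token attends over the 16 cached tokens
assert dec.forward(1) == 17
```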