Add pipeline support to eval methods (#3684)
* update complete method for pipeline

Signed-off-by: ericharper <complex451@gmail.com>

* broadcast tokens

Signed-off-by: ericharper <complex451@gmail.com>

* print args.prompt and response

Signed-off-by: ericharper <complex451@gmail.com>

* print response

Signed-off-by: ericharper <complex451@gmail.com>

* update compute_logprobs

Signed-off-by: ericharper <complex451@gmail.com>

* style

Signed-off-by: ericharper <complex451@gmail.com>

* remove imports

Signed-off-by: ericharper <complex451@gmail.com>
ericharper authored Feb 17, 2022
1 parent 6dd4263 · commit 31cf580
Showing 2 changed files with 108 additions and 15 deletions.
2 changes: 2 additions & 0 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -194,6 +194,8 @@ def pad_collate(batch):
print("***************************")
print(response)
print("***************************")
if args.prompt and not args.compute_logprobs:
print(f'Prompt: {args.prompt}\n\nResponse: {response[0][0][0]}')


if __name__ == '__main__':
121 changes: 106 additions & 15 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -54,7 +54,7 @@
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import get_num_microbatches, _reconfigure_microbatch_calculator

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
@@ -386,6 +386,20 @@ def loss_func(output_tensor):

return fwd_output_and_loss_func

def get_forward_output_only_func(self):
def fwd_output_only_func(batch, model):
batch = [x.cuda() for x in batch]
tokens, attention_mask, position_ids = batch
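# the causal attention mask is identical for every sample in the micro-batch, so only the first copy is passed to the model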
attention_mask = attention_mask[0:1]
output_tensor = model(tokens, position_ids, attention_mask)
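# id_func is applied by the apex schedule to the output of the last pipeline stage; returning the logits under the 'logits' key makes them available in the collected output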

def id_func(output_tensor):
return output_tensor, {'logits': output_tensor}

return output_tensor, id_func

return fwd_output_only_func

def validation_step(self, batch, batch_idx):
"""
Our dataloaders produce a micro-batch and then we fetch
@@ -828,9 +842,14 @@ def complete(self, request: Dict, positions: List, tokens_to_generate: int):
* offsets: list of tokens start positions in text
"""
# TODO: Add raise with message to use previous commit ID / version of NeMo
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
raise ValueError('complete method is not yet supported for pipeline')
app_state = AppState()
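# eval runs one sample at a time, so apex's microbatch calculator is reset to global/micro batch size 1 with no ramp-up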
_reconfigure_microbatch_calculator(
rank=app_state.global_rank,
rampup_batch_size=None,
global_batch_size=1,
micro_batch_size=1,
data_parallel_size=1,
)

results = []
request_tokens = request["tokens"]
@@ -839,6 +858,8 @@ def complete(self, request: Dict, positions: List, tokens_to_generate: int):

# For prompt tuned GPT models
if self.use_soft_prompts:
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
raise ValueError('complete method is not yet supported for pipeline with soft prompts')
prompt_tags = request["prompt_tags"][idx]
else:
prompt_tags = None
@@ -871,26 +892,54 @@ def complete(self, request: Dict, positions: List, tokens_to_generate: int):
eod_mask_loss=self.cfg.get('eod_mask_loss', False),
)

batch = [tokens, attention_mask, position_ids]
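# the apex schedules need the shape of the activations exchanged between stages: [sequence_length, micro_batch_size, hidden_size]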
tensor_shape = [tokens.shape[1], 1, self.cfg.hidden_size]
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
output_tensor = forward_backward_pipelining_without_interleaving(
forward_step_func=self.get_forward_output_only_func(),
batch=batch,
model=self.model,
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
)
else:
output_tensor = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_only_func(),
batch=batch,
model=self.model,
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
)

# get output tensor
# No labels during inference. Still need masks to not attend to the right
output_tensor = self(tokens, position_ids, attention_mask, prompt_tags=prompt_tags, labels=None)
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)
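# under pipeline parallelism only the last stage produces logits; other ranks build zero placeholders and receive the real values via the broadcasts below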
if parallel_state.is_pipeline_last_stage():
output_tensor = output_tensor[0]['logits']
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)

log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
tokens = torch.cat([tokens, torch.unsqueeze(token_ids[:, -1], 1)], dim=1)
else:
log_probs = torch.zeros((tokens.shape[0], tokens.shape[1]), dtype=torch.float).cuda()
tokens = torch.zeros((tokens.shape[0], tokens.shape[1] + 1), dtype=tokens.dtype).cuda()

log_probs, token_ids = torch.max(logsoftmaxlayer(output_tensor), dim=-1)
reached_eos = token_ids[0, -1].item() == self.tokenizer.eos_id
tokens = torch.cat([tokens, torch.unsqueeze(token_ids[:, -1], 1)], dim=1)
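# broadcast the sampled token ids and log-probs from the last pipeline rank so every rank assembles the same response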
torch.distributed.broadcast(tokens, get_last_rank())
torch.distributed.broadcast(log_probs, get_last_rank())

# add to results as (text, tokens, log_probs, offsets)
for token, prob in zip(tokens, log_probs.tolist()):
results.append(
(self.tokenizer.ids_to_text(token[:-1]), self.tokenizer.ids_to_tokens(token[:-1]), prob, [0])
(self.tokenizer.ids_to_text(token[:-1]), self.tokenizer.ids_to_tokens(token[:-1]), prob, [0],)
)

# offsets calculation
for item in results:
for index, token in enumerate(item[1]):
if index != len(item[1]) - 1:
item[3].append(len(token) + item[3][-1])
# returnprompts in order they were inputted

# return prompts in the order that they were input
response = [0 for i in range(len(positions))]
for item, index in zip(results, positions):
response[index] = item
@@ -914,13 +963,23 @@ def compute_logprobs(self, request: Dict, positions: List):
* log_probs: list of log_softmax's from output_tensor in respect to text tokens
* offsets: list of tokens start positions in text
"""
app_state = AppState()
_reconfigure_microbatch_calculator(
rank=app_state.global_rank,
rampup_batch_size=None,
global_batch_size=1,
micro_batch_size=1,
data_parallel_size=1,
)

results = []
request_tokens = request["tokens"]
logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1)
for idx, tokens in enumerate(request_tokens):
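# the model is run on all but the last token; the logits at position i score token i+1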
tokens_cut = tokens[:, :-1]
# For prompt tuned GPT models
if self.use_soft_prompts:
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
raise ValueError('compute_logprobs method is not yet supported for pipeline with soft prompts')
prompt_tags = request["prompt_tags"][idx]
else:
prompt_tags = None
@@ -947,8 +1006,39 @@ def compute_logprobs(self, request: Dict, positions: List):
reset_attention_mask=self.cfg.get('reset_attention_mask', False),
eod_mask_loss=self.cfg.get('eod_mask_loss', False),
)
output_tensor = self(tokens_cut, position_ids, attention_mask, prompt_tags=prompt_tags, labels=None)
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)

batch = [tokens_cut, attention_mask, position_ids]
tensor_shape = [tokens_cut.shape[1], 1, self.cfg.hidden_size]
if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
output_tensor = forward_backward_pipelining_without_interleaving(
forward_step_func=self.get_forward_output_only_func(),
batch=batch,
model=self.model,
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
)
else:
output_tensor = forward_backward_no_pipelining(
forward_step_func=self.get_forward_output_only_func(),
batch=batch,
model=self.model,
forward_only=True,
tensor_shape=tensor_shape,
dtype=self.autocast_dtype,
)

# get output tensor
if parallel_state.is_pipeline_last_stage():
output_tensor = output_tensor[0]['logits']
output_tensor = tensor_parallel.gather_from_tensor_model_parallel_region(output_tensor)

else:
output_tensor = torch.zeros(
(tokens_cut.shape[0], tokens_cut.shape[1], self.padded_vocab_size), dtype=torch.float
).cuda()
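# broadcast the gathered logits from the last pipeline stage so every rank can take the log-softmax locally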

torch.distributed.broadcast(output_tensor, get_last_rank())

log_probs = []
for output in output_tensor:
@@ -958,6 +1048,7 @@ def compute_logprobs(self, request: Dict, positions: List):

for token, prob in zip(tokens, log_probs):
results.append((self.tokenizer.ids_to_text(token), self.tokenizer.ids_to_tokens(token), prob, [0]))

# offsets calculation
for item in results:
for index, token in enumerate(item[1]):
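Taken together, complete and compute_logprobs now share one forward pattern, which the condensed sketch below restates in a single place. It is a sketch only: the method name run_forward_for_eval is not part of the commit, module-level imports are omitted because they are the ones megatron_gpt_model.py already declares, and every other name used (AppState, get_last_rank, the apex schedule functions, self.cfg, self.model, self.autocast_dtype, self.padded_vocab_size) appears in the diff above.

def run_forward_for_eval(self, tokens, attention_mask, position_ids):
    # Reset apex's microbatch calculator for single-sample inference.
    app_state = AppState()
    _reconfigure_microbatch_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=1,
        micro_batch_size=1,
        data_parallel_size=1,
    )

    batch = [tokens, attention_mask, position_ids]
    # Shape of the activations passed between pipeline stages: [seq_len, micro_batch, hidden].
    tensor_shape = [tokens.shape[1], 1, self.cfg.hidden_size]

    # Pick the apex schedule that matches the parallel configuration.
    if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
        fwd_bwd_func = forward_backward_pipelining_without_interleaving
    else:
        fwd_bwd_func = forward_backward_no_pipelining
    output = fwd_bwd_func(
        forward_step_func=self.get_forward_output_only_func(),
        batch=batch,
        model=self.model,
        forward_only=True,
        tensor_shape=tensor_shape,
        dtype=self.autocast_dtype,
    )

    # Only the last pipeline stage holds logits: gather them across the
    # tensor-parallel group there, and give every other rank a placeholder
    # that the broadcast fills in.
    if parallel_state.is_pipeline_last_stage():
        logits = output[0]['logits']
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
    else:
        logits = torch.zeros(
            (tokens.shape[0], tokens.shape[1], self.padded_vocab_size), dtype=torch.float
        ).cuda()
    torch.distributed.broadcast(logits, get_last_rank())
    return logits

The only structural difference between the two methods is the broadcast step: compute_logprobs broadcasts the full logits tensor as shown here, while complete broadcasts the sampled token ids and log-probs instead.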
