
Enable loading ckpt for t0 finetuning #309

Open · wants to merge 16 commits into main
Conversation

Muennighoff (Collaborator):
No description provided.

megatron/utils.py (resolved)
megatron/utils.py (resolved)
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
Member:

Suggested change:
-[0, 0, 0, 0, 0, 0, 0]]]]
+[0, 0, 0, 0, 0, 0, 1]]]]

Muennighoff (Collaborator, author):

I don't think there should be a 1, because the last row & column are 100% padding.

Member:

Hmm, I'm wondering if this doesn't screw something up. Essentially you're going to compute a softmax on a row with only zeros...

Muennighoff (Collaborator, author):

The last row & last column are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask, it doesn't matter, I think.
Also, it's a row with only -inf, no?

Member:

No, you compute a softmax; what should the result of a softmax over a row full of masked-out values be? It feels like that would return lots of NaNs.

Muennighoff (Collaborator, author):

Don't we fill it with -inf?
And the softmax of a row where all values are the same is just 1/n, no? Where would it cause NaNs?

thomasw21 (Member), Jul 12, 2022:

You can try writing a test, but I'm pretty sure the actual results are 0 (with the current kernel).
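
A quick check in plain PyTorch (not part of this PR, and not the fused kernel Megatron-DeepSpeed actually uses) illustrates both sides of this thread: with a -inf fill the vanilla softmax returns NaNs, while a large-but-finite fill gives the uniform 1/n row; the fused masked-softmax kernel may instead zero out such a row, as suggested above.

import torch

# Fully masked row filled with -inf: exp(-inf) = 0 everywhere, so the
# normalizer is 0 and vanilla softmax produces NaNs.
row = torch.full((1, 7), float("-inf"))
print(torch.softmax(row, dim=-1))   # tensor([[nan, nan, nan, nan, nan, nan, nan]])

# Fully masked row filled with a large negative finite value: all entries
# are equal, so the softmax is uniform (1/7 per position).
row = torch.full((1, 7), -10000.0)
print(torch.softmax(row, dim=-1))   # tensor([[0.1429, ..., 0.1429]])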

tools/preprocess_data.py (outdated, resolved)
finetune_t0_non_causal_decoder.py (outdated, resolved)
 loss_mask *= loss_on_targets_only * loss_on_non_pad_only

 attention_mask = get_packed_attention_mask(
-    # Run non-causal decoder
-    is_causal=False,
-    causal_mask=~(causal_mask.bool()),
+    is_causal=True,
Member:

let's rename this file finetune_t0_causal_decoder then

Muennighoff (Collaborator, author):

What about just finetune_t0.py?

Member:

Right, but do we hardcode this every time? I'd rather have this one be the script for the causal decoder.

Muennighoff (Collaborator, author):

Added an argument prefixlm
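
For context on the mask rows quoted above, here is a minimal sketch of how a packed attention mask can be built from segment ids; the real implementation is get_packed_attention_mask in megatron/utils.py, and the helper name, signature, and 1-means-attend convention below are assumptions made for illustration only.

import torch

# Hypothetical sketch, not the PR's get_packed_attention_mask: a token may only
# attend within its own packed segment, padding (segment id 0) attends to
# nothing, and an optional causal (lower-triangular) constraint is added on top.
def packed_attention_mask_sketch(segment_ids: torch.Tensor, is_causal: bool) -> torch.Tensor:
    seq_len = segment_ids.size(1)
    same_segment = segment_ids.unsqueeze(1) == segment_ids.unsqueeze(2)              # [b, s, s]
    not_padding = (segment_ids != 0).unsqueeze(1) & (segment_ids != 0).unsqueeze(2)  # [b, s, s]
    mask = same_segment & not_padding
    if is_causal:
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=segment_ids.device))
        mask = mask & causal
    return mask.unsqueeze(1)  # [b, 1, s, s], 1 = may attend

# Two packed sequences of length 3 plus one padding position; the last four
# rows of the printed mask match the rows quoted earlier in this thread,
# including the all-zero final (padding) row.
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
print(packed_attention_mask_sketch(segment_ids, is_causal=True).int())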

finetune_t0_non_causal_decoder.py (outdated, resolved)
megatron/model/gpt_model.py (outdated, resolved)
megatron/model/gpt_model.py (outdated, resolved)

* Tmp lossseq

* Efficient loss normalization

* Reuse variable

* Simplify division

* Add norm_target_loss arg

* Clarify loss on targets & remove kwarg

* Loss mask is already float

* Move norm to batch pipe

* Reshape loss mask

* Move view
thomasw21 (Member) left a comment:

Nice work! Some things I think shouldn't be in this PR.

    group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                       help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                            'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
                            'positions based on how frequently we train on that position. '
                            'This is mostly used for prefix_lm training')
    group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
    group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
    group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")
Member:

Yeah, actually let's remove that option. I don't think we've trained one successfully. We'll probably do it, since people have shown that it works, but in another PR IMO.

@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     load_dir = getattr(args, load_arg)

     if args.deepspeed:
-        load_optimizer_states = False if args.no_load_optim else True
+        load_optimizer_states = not args.no_load_optim
         loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
Member:

Just use no_load_optim directly in the method
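
One possible reading of that suggestion, sketched against the call quoted in the diff above (dropping the temporary variable is the only change):

if args.deepspeed:
    # Sketch of the suggestion: use the no_load_optim flag inline in the call
    # instead of going through an intermediate load_optimizer_states variable.
    loaded_dir, state_dict = model[0].load_checkpoint(
        load_dir, load_optimizer_states=not args.no_load_optim)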

@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))

     # Set iteration.
-    if args.finetune or release:
+    if args.finetune or release or args.reset_progress:
Member:

Why is it that we didn't set finetune to True?

@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         # Check arguments.
         assert args.consumed_train_samples == 0
         assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
+    if 'args' in state_dict and not args.reset_progress:
Member:

Can you add a comment? Typically this is only used because the metadata loading mechanism screws with us.
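
A sketch of what the requested comment could say; the wording is a guess based on this thread, not text from the PR:

# Skip restoring metadata from the checkpoint's saved args when resetting
# progress: the metadata loading mechanism would otherwise overwrite the
# freshly reset consumed_train_samples / consumed_valid_samples.
if 'args' in state_dict and not args.reset_progress:
    ...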

@@ -399,7 +399,7 @@ def _build_index_mappings(
     shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
Member:

Why do you need that?

Muennighoff (Collaborator, author):

Afaik you added this code; I think it was for running tests or something.

Member:

Arf, probably because I wanted to use the data loader only... Maybe let's remove it for now, because we should assume that torch.distributed is always initialized, at least in Meg-DS IMO.

     assert counts[0].item() == (
         torch.distributed.get_world_size() //
         torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+    if torch.distributed.is_initialized():
Member:

Ditto

Comment on lines +107 to +108
loss_mask = loss_mask.view(-1)
loss_mask = fast_normalize(loss_mask)
Member:

Maybe reshaping to the original structure is a better API? It's better to have the same shape as the labels IMO (we still flatten everything anyway).
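
A sketch of the reshaping being suggested; fast_normalize is the PR's helper and its internals are not shown here, so this only illustrates the shape handling:

# Normalize, then return loss_mask in the same [batch, seq] shape as the labels.
orig_shape = loss_mask.shape
loss_mask = fast_normalize(loss_mask.view(-1))
loss_mask = loss_mask.view(orig_shape)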

@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):
         last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])


-    def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
+    def test_finetune_t0_get_batch_pipe(self):
Member:

Yeah, let's make it so that the script is causal-decoder specific. Let's figure out the non-causal decoder later on.

    group.add_argument('--append-bos', action='store_true',
                       help='Append a bos token to the end of a document.')
    group.add_argument('--prepend-space', action='store_true',
                       help='Prepends a space to the beginning of a document')
Member:

Add a mention of the context in which it's useful; typically it's when you compute targets.
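
A sketch of the requested help text; the exact phrasing is a guess, not text from the PR:

    group.add_argument('--prepend-space', action='store_true',
                       help='Prepends a space to the beginning of a document. '
                            'Mainly useful when computing targets, so that the '
                            'target text tokenizes with a leading space.')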
