GPT extrapolatable position embedding (xpos/sandwich/alibi/kerple) and Flash Attention #6666

Merged · 571 commits · Jun 12, 2023
Changes from 22 commits

Commits
a8564d3
move to nvidia megatron repo (#6465) (#6475)
github-actions[bot] Apr 24, 2023
7a17f73
Megatron KERPLE positional embeddings (#6478) (#6480)
github-actions[bot] Apr 24, 2023
a67b00f
Fix an invalid link in get_data.py of ljspeech (#6456)
pythinker Apr 24, 2023
1e1fbbe
1. Added external index sample. (#6462) (#6483)
github-actions[bot] Apr 25, 2023
4561e12
Update README to add core installation (#6488) (#6489)
github-actions[bot] Apr 25, 2023
599f522
Fix cache aware hybrid bugs (#6466) (#6484)
github-actions[bot] Apr 25, 2023
ae4a4dd
Fix typos (#6494) (#6495)
github-actions[bot] Apr 26, 2023
df2b870
Add disclaimer about dataset for ASR (#6496)
titu1994 Apr 26, 2023
0c85e21
fix (#6502)
Jorjeous Apr 26, 2023
24c77d0
fix broken links r1.18.0 (#6501) (#6504)
github-actions[bot] Apr 26, 2023
07f6533
[TTS] Create functions for TTS preprocessing without dataloader (#6317)
rlangman Apr 27, 2023
8bffc80
Cache aware streaming nfa (#6209)
Slyne Apr 27, 2023
6b84a8a
[BugFix] Force _get_batch_preds() to keep logits in decoder timestamp…
tango4j Apr 28, 2023
56ce2a6
[TTS] Fix FastPitch energy code (#6511)
rlangman Apr 28, 2023
b460716
fix custom forward_torch_softmax (#6512) (#6517)
github-actions[bot] Apr 28, 2023
319b191
[TTS] fixed broken path. (#6514) (#6518)
github-actions[bot] Apr 28, 2023
2dd91fa
Fix normalization of impulse response in ImpulsePerturbation (#6505)
anteju Apr 28, 2023
d0e2f5a
Add interleaved pp support (#6498)
titu1994 Apr 28, 2023
3cff6ce
Fix typos (#6523)
titu1994 May 1, 2023
c2a4264
New noise_norm perturbation based on Riva work (#6445)
trias702 May 2, 2023
669a8c2
[TTS] Add script for computing feature stats (#6508)
rlangman May 2, 2023
798978d
Add Frame-VAD model and datasets (#6441)
stevehuang52 May 2, 2023
cb53ede
Support dynamic length batches with GPT SFT (#6510)
aklife97 May 2, 2023
1217668
added back the fast emit section to the configs. (#6540) (#6542)
github-actions[bot] May 3, 2023
5090a94
removing unnessary avoid_bfloat16_autocast_context (#6481)
bmwshop May 3, 2023
b2f23bd
FC models in menu (#6473)
bmwshop May 3, 2023
6c77583
[TTS] Add tutorials for FastPitch TTS speaker adaptation with adapter…
hsiehjackson May 3, 2023
ce84b1f
[TTS] Create initial TTS dataset feature processors (#6507)
rlangman May 3, 2023
8bbc140
fix (#6529) (#6546)
github-actions[bot] May 3, 2023
dc0c332
Add FastConformer Hybrid ASR models for EN, ES, IT, DE, PL, HR, UA, B…
github-actions[bot] May 4, 2023
42691c3
Add scores for FastConformer models (#6557) (#6558)
github-actions[bot] May 4, 2023
e7f2210
Fix fp16 (#6543) (#6544)
github-actions[bot] May 4, 2023
69b2c34
Patch transcribe and support offline transcribe for hybrid model (#65…
github-actions[bot] May 4, 2023
24076ca
Fix notebook bad json (#6561)
titu1994 May 4, 2023
b41a511
Change Megatron Enc Dec model to use persistent_workers (#6548) (#6552)
github-actions[bot] May 4, 2023
77369ef
Make KenLM with PC for AggregateTokenizer and merge it (#6081)
karpnv May 4, 2023
fa62794
fix for running on 1 GPU.
khcs May 4, 2023
3817d41
temp rtd fix (#6568) (#6569)
github-actions[bot] May 4, 2023
a57ec70
[TTS] Add script for mapping speaker names to indices (#6509)
rlangman May 5, 2023
5fd9c7f
whitespace (#6574)
karpnv May 5, 2023
04c1b72
Update manifest.py for speedup (#6565) (#6573)
github-actions[bot] May 5, 2023
c13ffb9
More streaming conformer export fixes (#6567) (#6578)
github-actions[bot] May 5, 2023
846fc83
user selected max_seq_len should be less than model's max_seq_len (#6…
github-actions[bot] May 5, 2023
c19aac5
Framework for PEFT via mixins (#6391)
arendu May 5, 2023
fba50b8
cache and reuse inputs (#6422) (#6452)
github-actions[bot] May 7, 2023
d0785d5
Add patches for Virtual Parallel conversion (#6589)
titu1994 May 8, 2023
c7f58d8
Pass `.scale` instead of scaler object to core (#6551)
github-actions[bot] May 8, 2023
58440fb
Documentation for ASR-TTS models (#6594) (#6595)
github-actions[bot] May 8, 2023
aa2b9b8
[TTS] Fix aligner nan loss in fp32 (#6435)
hsiehjackson May 8, 2023
cf60b6c
Update SDP docs (#6485) (#6596)
github-actions[bot] May 8, 2023
3c1147f
Bug/typo fixes (#6599)
Kipok May 9, 2023
08ab1a7
Manual garbage collection with an interval (#6469) (#6482)
github-actions[bot] May 9, 2023
3ed0282
Make tensor split contiguous (#6580) (#6593)
github-actions[bot] May 9, 2023
a9d2910
[ASR] Fix for old models in change_attention_model (#6608)
sam1373 May 10, 2023
077b7f9
Update manifest.py to use os.path for get_full_path (#6598)
stevehuang52 May 10, 2023
9eed6d3
Cherry pick commits in #6601 to main (#6611)
fayejf May 10, 2023
77b9a85
Create dummy iters to satisy len checks (#6600) (#6603)
github-actions[bot] May 10, 2023
9f367f4
add GPT eval mode fix for interleaved to main (#6610)
aklife97 May 10, 2023
8592562
Fix batch size reconf for T5 FT for multi-validation (#6582) (#6588)
github-actions[bot] May 10, 2023
b3f5f39
Not doing CastToFloat by default (#6524) (#6563)
github-actions[bot] May 10, 2023
09f2e37
Turn autocast off when precision is fp32 (#6576)
github-actions[bot] May 10, 2023
2a446cb
update core commit hash in readme (#6622) (#6623)
github-actions[bot] May 10, 2023
2cc0f62
add hat image to docs (#6619) (#6621)
github-actions[bot] May 11, 2023
94e6e25
Allow indices exchange via distributed (#6618) (#6624)
github-actions[bot] May 11, 2023
7f48130
Offline and streaming inference support for hybrid model (#6570)
fayejf May 11, 2023
c44e3b6
Patch decoding for PC models (#6630) (#6631)
github-actions[bot] May 11, 2023
ef49b0a
Fix wer.py where 'errors' variable was not set (#6633) (#6634)
github-actions[bot] May 11, 2023
1b785e2
Restore GPT support for interleaved pipeline parallelism (#6528) (#6613)
timmoon10 May 11, 2023
44e890e
Add FA
hsiehjackson May 12, 2023
a5fcbee
Fix XPOS
hsiehjackson May 12, 2023
aedcc7c
Add warning
hsiehjackson May 12, 2023
7fbf571
Fix bugs
hsiehjackson May 13, 2023
ddb067e
Fix attention
hsiehjackson May 13, 2023
81a8c21
Fix comment
hsiehjackson May 15, 2023
36d685b
Fix cast dtype
hsiehjackson May 15, 2023
a1d1e5a
Undo xpos
hsiehjackson May 15, 2023
2eaa60a
bugfix (#6636)
fayejf May 11, 2023
5eb3552
Disable interctc tests (#6638)
Kipok May 11, 2023
4e94268
Add megatron_core to requirements (#6639) (#6640)
github-actions[bot] May 11, 2023
56847f3
Remove from jenkins (#6642)
github-actions[bot] May 11, 2023
986feed
sft model can use this script for eval (#6637)
arendu May 12, 2023
6d2c969
[TTS] Fix TTS audio preprocessing bugs (#6628)
rlangman May 12, 2023
954d43f
Move black parameters to pyproject.toml (#6647)
artbataev May 12, 2023
11c58f3
ASR-TTS Models: Support hybrid RNNT-CTC, improve docs. (#6620)
artbataev May 12, 2023
db7d578
fix conversion and eval (#6648)
arendu May 13, 2023
acb2c56
Confidence ensembles implementation (#6614)
Kipok May 15, 2023
1b28a7b
Patch memory used for NeMo Megatron models (#6615)
titu1994 May 15, 2023
6fb6e47
handle artifacts when path is dir (#6658)
arendu May 16, 2023
4ccba61
remove upgrading setuptools in reinstall.sh (#6659)
XuesongYang May 16, 2023
82d5d58
merge lora weights into base model (#6597)
arendu May 16, 2023
89b428c
upgrade to 23.04 (#6660)
ericharper May 16, 2023
9683d02
Merge r1.18.0 bugfixes and doc updates to main (#6655)
ericharper May 16, 2023
c648d99
Confidence ensembles: fix issues and add tuning functionality (#6657)
Kipok May 16, 2023
f736f60
[TTS] Implement new TextToSpeech dataset (#6575)
rlangman May 16, 2023
4e7afbb
Dialogue dataset (#6654)
yidong72 May 16, 2023
7e62925
Add support for RNNT/hybrid models to partial transcribe (#6609)
stevehuang52 May 16, 2023
e009385
eval_beamsearch_ngram.py with hybrid ctc (#6656)
karpnv May 17, 2023
c5e229a
fix bucketing bug issue for picking new bucket (#6663)
nithinraok May 17, 2023
9d7d0b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2023
b739a5e
Add t5 flash-attention
hsiehjackson May 18, 2023
473ff20
PE refactor (#6673)
hsiehjackson May 18, 2023
4a0699d
Add singleton alibi
hsiehjackson May 18, 2023
9cfea92
Fix FA mask
hsiehjackson May 18, 2023
8c3bfbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2023
9d01255
singleton PE
hsiehjackson May 18, 2023
8a6e294
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2023
8bd1466
Fix attn bias inference
hsiehjackson May 22, 2023
0e02478
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2023
8ed6a0a
fix eval
ekmb May 19, 2023
a6b856c
[TTS] Add callback for saving audio during FastPitch training (#6665)
rlangman May 18, 2023
213b5a3
update batch size recommendation to min 32 for 43b (#6675)
Zhilin123 May 18, 2023
1b93141
Make Note usage consistent in adapter_mixins.py (#6678)
BrianMcBrayer May 18, 2023
d2938b9
Fix masking bug for TTS Aligner (#6677)
redoctopus May 18, 2023
1564d94
[ASR] Adding ssl config for fast-conformer (#6672)
krishnacpuvvada May 19, 2023
82f863b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2023
8b55842
Fix xpos offset
hsiehjackson May 23, 2023
fbdd7fe
Fix sequence parallel
hsiehjackson May 24, 2023
8535a6a
Fix parallel
hsiehjackson May 24, 2023
873f2e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2023
7847a54
Uncomment correct bias size
hsiehjackson May 24, 2023
4aa46d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2023
7514bf4
Remove unused module
hsiehjackson May 25, 2023
9a133d0
Fix singleton tril
hsiehjackson May 25, 2023
5ce3819
Fix kerple/sandwitch rename xpos
hsiehjackson May 25, 2023
bbee276
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 25, 2023
de61214
fix sandwich
hsiehjackson May 25, 2023
dcab11e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 25, 2023
cd3bb6d
Add unitest
hsiehjackson May 30, 2023
4fac042
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 25, 2023
129e55d
Fix bug
hsiehjackson May 30, 2023
3b5ec97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2023
b2eb222
Add requirements
hsiehjackson May 30, 2023
c73f983
Remove requirements
hsiehjackson May 30, 2023
06ce313
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2023
8c969fe
Remove requirement flash-attn
hsiehjackson May 30, 2023
f70cc3f
Fix FA causal for inference
hsiehjackson Jun 1, 2023
a0cea83
Add experimental PE
hsiehjackson Jun 1, 2023
c7c6a1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
6876703
Update all invalid tree references to blobs for NeMo samples (#6679)
BrianMcBrayer May 19, 2023
6c65625
Update README.rst about container (#6686)
fayejf May 19, 2023
456153a
Fix a bug, use _ceil_to_nearest instead as _round_to_nearest is not d…
github-actions[bot] May 20, 2023
69992d6
Enable ONNX export of 5B GPT trained with TE FP8 modules (#6458)
asfiyab-nvidia May 22, 2023
4ee6d8f
[TTS] Add script for text preprocessing (#6541)
rlangman May 22, 2023
c856936
[TTS] Fix adapter duration issue (#6697)
hsiehjackson May 22, 2023
b70dbf7
karpnv/issues6690 (#6705)
karpnv May 23, 2023
1a66d30
Limit codeql scope (#6710)
titu1994 May 23, 2023
ff772f7
eval fix (#6685)
arendu May 23, 2023
2231a57
Fix k2 installation in Docker with CUDA 12 (#6707) (#6709)
github-actions[bot] May 24, 2023
8b3dce5
[TTS] Filter out silent audio files during preprocessing (#6716)
rlangman May 24, 2023
963855b
not pinning version (#6680)
yidong72 May 24, 2023
b0f33f1
Tutorial fixes (#6717) (#6718)
github-actions[bot] May 24, 2023
a4ef711
preprocess squad in sft format (#6727)
arendu May 25, 2023
da5e6f8
Fix Codeql (#6731)
titu1994 May 25, 2023
2c35e0b
[TTS] fix inconsistent type hints for IpaG2p (#6733)
XuesongYang May 26, 2023
2bac13d
VP Fixes for converter + Config management (#6698)
titu1994 May 26, 2023
5831405
Graph RNNT: Grid- and Compose-Transducer. W-Transducer loss (#6168)
artbataev May 26, 2023
2e963da
Fix fastpitch test nightly (#6730)
hsiehjackson May 26, 2023
7f83283
Fix for interctc test random failure (#6644)
Kipok May 26, 2023
599c503
check for first or last stage (#6708) (#6743)
github-actions[bot] May 27, 2023
0725b2d
sharded manifests docs (#6751)
bmwshop May 29, 2023
bdeab5b
[TTS] relax hardcoded prefix for phonemes and tones and infer phoneme…
XuesongYang May 30, 2023
146371b
[TTS] corrected misleading deprecation warnings. (#6702)
XuesongYang May 30, 2023
8f43ae3
Bug fix to restore act ckpt (#6753) (#6755)
github-actions[bot] May 31, 2023
7daad62
Bug fix to reset sequence parallelism (#6756) (#6770)
github-actions[bot] May 31, 2023
49e016e
Fix TTS adapter tutorial (#6741)
hsiehjackson May 31, 2023
34f5452
Fix checkpointed forward and add test for full activation checkpointi…
github-actions[bot] May 31, 2023
c022acb
lora notebook (#6765)
arendu May 31, 2023
e98f425
Fix Links (#6777) (#6778)
github-actions[bot] May 31, 2023
bcb3fd3
Remove alibi tril
hsiehjackson Jun 1, 2023
71bff2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
2e6eba5
Add flash-attn requirement
hsiehjackson Jun 1, 2023
424a15d
revert sft dataset changes
ekmb Jun 1, 2023
e79a35a
Move flash-attn requirement
hsiehjackson Jun 1, 2023
4c953aa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
0b18768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
8863360
Add install
hsiehjackson Jun 1, 2023
2dc0418
peft eval directly from ckpt (#6785)
arendu Jun 1, 2023
1353aca
Add Frame-VAD examples and utils (#6463)
stevehuang52 Jun 1, 2023
b8d19b2
[TTS][zh] refine hardcoded lowercase for ASCII letters. (#6781)
XuesongYang Jun 2, 2023
7ad325d
Revert evaluation
hsiehjackson Jun 2, 2023
b875a78
Revert evaluation
hsiehjackson Jun 2, 2023
1f229c0
Fix
hsiehjackson Jun 2, 2023
26dbc9f
Fix gpu
hsiehjackson Jun 2, 2023
a3cf08e
Spellchecking ASR customization model (#6179)
bene-ges Jun 2, 2023
90ef33a
Fix test
hsiehjackson Jun 2, 2023
380a6f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
b69cbf7
Fix device
hsiehjackson Jun 2, 2023
de52c2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
8dc863b
Fix conflict
hsiehjackson Jun 2, 2023
29c3cd4
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 2, 2023
e782202
Revert
hsiehjackson Jun 2, 2023
7f40a05
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 2, 2023
d814f47
clean
hsiehjackson Jun 2, 2023
65118c4
Change device
hsiehjackson Jun 2, 2023
89d4547
Change device
hsiehjackson Jun 2, 2023
9c50e29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
218ffa3
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 2, 2023
84acce0
Add test FA
hsiehjackson Jun 5, 2023
874f992
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 5, 2023
35ac850
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 5, 2023
98783ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2023
08dbd86
Add CI
hsiehjackson Jun 5, 2023
bdfe61e
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 5, 2023
6df2df8
Fix yaml order
hsiehjackson Jun 5, 2023
1f460d9
Test random attention mask
hsiehjackson Jun 5, 2023
01f4391
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2023
23634bf
Add install FA for tests
hsiehjackson Jun 6, 2023
4cfb2da
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 6, 2023
528c416
cherry pick 6788 (#6816)
ekmb Jun 6, 2023
a751928
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 6, 2023
ee692d4
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 6, 2023
5178f6b
Support 2D mask
hsiehjackson Jun 6, 2023
45876ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
1c15644
add missing comp_att_mask arg
ekmb Jun 6, 2023
74da509
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
ekmb Jun 6, 2023
5da1bc3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
fb895da
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 6, 2023
81d2fb0
Fix code ql
hsiehjackson Jun 6, 2023
b578ff5
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 6, 2023
82120c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
a9bb73e
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 6, 2023
662733b
Megatron MPT-7B Support (#6804)
trias702 Jun 7, 2023
6b18be2
Fix test triton
hsiehjackson Jun 7, 2023
bdd91d6
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 7, 2023
92e7dba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2023
bb89e61
Update FA in CI
hsiehjackson Jun 7, 2023
672f262
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 7, 2023
2a526ad
Fix Jenkin error
hsiehjackson Jun 7, 2023
0ac5374
Resume with FA
hsiehjackson Jun 7, 2023
7acf5cf
Follow comments
hsiehjackson Jun 7, 2023
cdff779
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2023
dcab29d
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 7, 2023
aba44ae
Fix README
hsiehjackson Jun 7, 2023
7c0a530
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 7, 2023
194b4bb
Fix README
hsiehjackson Jun 7, 2023
fe173c1
Remove torch.cuda
hsiehjackson Jun 7, 2023
7c38447
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 7, 2023
1104174
Remove unused import
hsiehjackson Jun 7, 2023
a3010bd
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 7, 2023
81b002e
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 8, 2023
a883aa2
kerple init
hsiehjackson Jun 8, 2023
6a895f0
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 8, 2023
0504814
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 8, 2023
889dec6
Add TE comment
hsiehjackson Jun 8, 2023
fd2899a
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 8, 2023
7255e31
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 8, 2023
c972553
Merge branch 'main' into gpt-alibi-FA
hsiehjackson Jun 9, 2023
b8b5611
Fix error when inference.compute_attention_mask=False
hsiehjackson Jun 9, 2023
83ef08d
Merge branch 'gpt-alibi-FA' of https://github.com/NVIDIA/NeMo into gp…
hsiehjackson Jun 9, 2023
498ec3d
Merge branch 'main' into gpt-alibi-FA
michalivne Jun 11, 2023
@@ -77,7 +77,7 @@ model:
  transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
  openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
  normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to set this to True.
  position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope']
  position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'xpos', 'sandwich']
  rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this.
  attention_type: 'multihead' # Attention type. Options ['multihead']
  share_embeddings_and_output_weights: True # Share embedding and output layer weights.
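For orientation, here is a minimal sketch (not part of this PR) of how the new knobs could be set programmatically. Only the key names position_embedding_type and use_flash_attention are taken from the diffs; the OmegaConf wrapper and chosen values are illustrative.

# Minimal sketch (not part of the diff): overriding the new config keys with OmegaConf.
# Only the key names come from this PR; structure and values are illustrative.
from omegaconf import OmegaConf

overrides = OmegaConf.create(
    {
        'model': {
            'position_embedding_type': 'xpos',  # one of the options listed above
            'use_flash_attention': True,  # plumbed through the GPT model diffs below
        }
    }
)

allowed = {'learned_absolute', 'rope', 'alibi', 'xpos', 'sandwich'}
assert overrides.model.position_embedding_type in allowed
print(OmegaConf.to_yaml(overrides))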
@@ -161,6 +161,7 @@ def __init__(
        fp8_amax_compute_algo='most_recent',
        reduce_amax=True,
        use_emha=False,
        use_flash_attention=False,
    ):
        super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -239,6 +240,7 @@ def __init__(
            fp8_amax_compute_algo=fp8_amax_compute_algo,
            reduce_amax=reduce_amax,
            use_emha=use_emha,
            use_flash_attention=use_flash_attention,
        )

        if self.share_embeddings_and_output_weights:
@@ -253,6 +253,7 @@ def model_provider_func(self, pre_process, post_process):
            fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
            reduce_amax=self.cfg.get('reduce_amax', True),
            use_emha=self.cfg.get('use_emha', False),
            use_flash_attention=self.cfg.get('use_flash_attention', False),
        )

        return model
316 changes: 211 additions & 105 deletions nemo/collections/nlp/modules/common/megatron/attention.py

Large diffs are not rendered by default.
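Since the attention.py changes are not rendered here, the following is a generic, hedged sketch of the pattern the new position embeddings rely on: an additive per-head bias of shape [num_heads, sq, sk] is added to the raw attention scores before the causal mask and softmax. None of these names come from attention.py; this is plain PyTorch for illustration only.

# Generic sketch (not the actual attention.py changes): folding an additive
# relative position bias into scaled dot-product attention.
import math
import torch

def attention_with_bias(q, k, v, position_bias=None, causal=True):
    # q, k, v: [batch, num_heads, seq_len, head_dim]; position_bias: [num_heads, sq, sk]
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    if position_bias is not None:
        scores = scores + position_bias.unsqueeze(0)  # broadcast over the batch
    if causal:
        sq, sk = scores.shape[-2:]
        mask = torch.ones(sq, sk, dtype=torch.bool, device=scores.device).triu(1)
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

q = k = v = torch.randn(2, 4, 8, 16)
bias = torch.zeros(4, 8, 8)  # e.g. an ALiBi/KERPLE/sandwich bias
out = attention_with_bias(q, k, v, bias)
assert out.shape == (2, 4, 8, 16)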

@@ -0,0 +1,94 @@
import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
    from apex.transformer import parallel_state, tensor_parallel
    from apex.transformer.enums import AttnMaskType, AttnType
    from apex.transformer.utils import divide as safe_divide

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

    HAVE_APEX = False

    # fake missing classes with None attributes
    ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults()


def get_kerple_log_params(num_attention_heads, precision):

    try:
        model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()
    except:
        model_parallel_size = 1
    num_heads_per_partition = safe_divide(num_attention_heads, model_parallel_size)

    dtype_dict = {16: torch.float16, 32: torch.float32, 'bf16': torch.bfloat16}

    def get_parameter(scale, init_method):
        if init_method == 'ones':
            return Parameter(
                torch.ones(
                    num_heads_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=dtype_dict[precision],
                )[:, None, None]
                * scale
            )
        elif init_method == 'uniform':
            return Parameter(
                torch.rand(
                    num_heads_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=dtype_dict[precision],
                )[:, None, None]
                * scale
            )

    bias_p = get_parameter(2, 'uniform')
    bias_a = get_parameter(1, 'uniform')

    return torch.concat((bias_p, bias_a))


def kerple_log_forward(seq_len_q, seq_len_k, relative_position_bias):
    # The packed tensor holds bias_p and bias_a for the local attention heads.
    bias_p, bias_a = torch.split(relative_position_bias, relative_position_bias.size(0) // 2, dim=0)

    eps = 1e-2

    # We may be able to save this and avoid recomputing it every time like in the
    # reference implementation.
    # Currently kept this way to be compatible with the checkpointed-attn-forward.
    # TODO: find a way to avoid recomputing this every time.
    diff = torch.tril(
        torch.arange(seq_len_k, device=relative_position_bias.device).view(seq_len_k, 1).repeat(1, seq_len_k)
        + torch.arange(0, -seq_len_k, -1, device=relative_position_bias.device)
    )
    diff = diff.to(relative_position_bias.dtype)

    bias_p.data = bias_p.data.clamp(min=eps)
    bias_a.data = bias_a.data.clamp(min=eps)
    bias = -bias_p * torch.log(1 + bias_a * diff)  # log kernel

    if seq_len_q != seq_len_k:
        # In the train case the bias has dimensionality [b, np, sq, sk] with sq == sk:
        # the number of query tokens equals the number of key tokens.
        # At inference time with a cache in layer_past, sq is not equal to sk: sq contains
        # only one token (the last one in the full sequence), so the appropriate token index
        # of the cache matrix is used. As the cache matrix could already be bigger from a
        # past inference, it is not the last token index of the sq sequence that is used.
        assert (
            seq_len_q == 1
        ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"

        if type(bias) != float:
            # seq_len_k - 1 points to the last token index in the current inference batch.
            bias = bias[:, seq_len_k - 1, :].view(bias.shape[0], 1, bias.shape[2])

    return bias
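A small usage sketch (not part of the diff) for the helpers above. get_kerple_log_params needs Apex and a CUDA device, so the packed [2 * num_heads, 1, 1] parameter tensor is emulated here with torch.rand; the expected output shapes follow the comments in kerple_log_forward.

# Usage sketch (not part of the diff): exercising kerple_log_forward on CPU.
import torch

num_heads, sk = 4, 8
relative_position_bias = torch.rand(2 * num_heads, 1, 1)  # stand-in for get_kerple_log_params output

# Training-style call: sq == sk, bias has shape [num_heads, sk, sk].
bias = kerple_log_forward(sk, sk, relative_position_bias)
assert bias.shape == (num_heads, sk, sk)

# Inference-style call with a KV cache: sq == 1, bias has shape [num_heads, 1, sk].
bias_step = kerple_log_forward(1, sk, relative_position_bias)
assert bias_step.shape == (num_heads, 1, sk)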
51 changes: 49 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -23,6 +23,12 @@
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import RotaryEmbedding
from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
from nemo.collections.nlp.modules.common.megatron.alibi_relative_position_embedding import (
    ALiBiRelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.kerple_relative_position_embedding import (
    KERPLERelativePositionEmbedding,
)
from nemo.collections.nlp.modules.common.megatron.utils import (
    ApexGuardDefaults,
    get_linear_layer,
@@ -114,6 +120,7 @@ def get_language_model(
    fp8_amax_compute_algo='most_recent',
    reduce_amax=True,
    use_emha=False,
    use_flash_attention=False,
):
    """Build language model and return along with the key to save."""

@@ -188,6 +195,7 @@ def get_language_model(
        fp8_amax_compute_algo=fp8_amax_compute_algo,
        reduce_amax=reduce_amax,
        use_emha=use_emha,
        use_flash_attention=use_flash_attention,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'
@@ -487,6 +495,7 @@ def __init__(
        fp8_amax_compute_algo='most_recent',
        reduce_amax=True,
        use_emha=False,
        use_flash_attention=False,
    ):
        super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -538,6 +547,30 @@ def __init__(
                rotary_dim = int(rotary_dim * rotary_percentage)
            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)

        elif position_embedding_type == 'alibi':
            # TODO: If this is used for an encoder-decoder model, implement proper logic
            # and the corresponding addition for the decoder. Currently it is only used
            # for decoder-only models. Encoder-decoder models such as T5 are implemented
            # in token_level_encoder_decoder.py.
            self.encoder_relative_position_embedding = ALiBiRelativePositionEmbedding(
                bidirectional=False,
                num_attention_heads=num_attention_heads,
                layer_type=LayerType.encoder,
                num_attention_heads_alibi=None,
                max_seq_len=max_position_embeddings,
            )

        elif position_embedding_type == 'kerple':
            # TODO: If this is used for an encoder-decoder model, implement proper logic
            # and the corresponding addition for the decoder. Currently it is only used
            # for decoder-only models. Encoder-decoder models such as T5 are implemented
            # in token_level_encoder_decoder.py.
            self.decoder_relative_position_embedding = KERPLERelativePositionEmbedding(
                bidirectional=False,
                num_attention_heads=num_attention_heads,
                layer_type=LayerType.decoder,
                num_attention_heads_kerple=None,
                max_seq_len=max_position_embeddings,
            )

        # Transformer.
        self.encoder = ParallelTransformer(
            init_method=self.init_method,
@@ -588,6 +621,8 @@ def __init__(
            fp8_amax_compute_algo=fp8_amax_compute_algo,
            reduce_amax=reduce_amax,
            use_emha=use_emha,
            position_embedding_type=position_embedding_type,
            use_flash_attention=use_flash_attention,
        )
        self._encoder_key = 'encoder'

@@ -627,6 +662,8 @@ def __init__(
            activations_checkpoint_granularity=activations_checkpoint_granularity,
            activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline,
            transformer_engine=transformer_engine,
            position_embedding_type=position_embedding_type,
            use_flash_attention=use_flash_attention,
        )
        self._decoder_key = 'decoder'

@@ -697,6 +734,8 @@ def forward(

        # enc_attn_mask: [1, 1, s, s]

        rotary_pos_emb = None
        encoder_self_attention_relative_position_bias = None
        if self.position_embedding_type == 'rope':
            if inference_max_sequence_len is not None:
                rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
@@ -714,8 +753,14 @@ def forward(
                )
            else:
                rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
        else:
            rotary_pos_emb = None
        elif self.position_embedding_type == 'alibi':
            enc_seq_length = enc_input_ids.size(1)
            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding(
                query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,
            )
        elif self.position_embedding_type == 'kerple':
            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding

        # encoder.
        if enc_hidden_states is None:
@@ -730,6 +775,8 @@ def forward(
                rotary_pos_emb=(rotary_pos_emb, None, None)
                if rotary_pos_emb is not None
                else None,  # This assumes that this is being used as a GPT/BERT model only (no cross-attention)
                self_attention_relative_position_bias=encoder_self_attention_relative_position_bias
                if encoder_self_attention_relative_position_bias is not None else None
            )
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)
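ALiBiRelativePositionEmbedding and KERPLERelativePositionEmbedding are defined in separate files of this PR and are not shown in this view. For reference, here is a hedged sketch of the standard ALiBi construction from the original paper (geometric per-head slopes times negative key distance); NeMo's module may differ in details such as how heads are partitioned across tensor-parallel ranks.

# Sketch of the textbook ALiBi bias (Press et al.), for orientation only.
import torch

def alibi_bias(num_heads, seq_len):
    # Geometric slopes 2^(-8/n), 2^(-16/n), ... for n heads (power-of-two head counts).
    slopes = torch.tensor([2 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
    positions = torch.arange(seq_len)
    distance = positions[None, :] - positions[:, None]  # [sq, sk], positive above the diagonal
    distance = distance.tril()  # keep only keys at or before each query
    return slopes[:, None, None] * distance  # [num_heads, sq, sk], values <= 0

bias = alibi_bias(num_heads=8, seq_len=16)
assert bias.shape == (8, 16, 16)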
@@ -87,6 +87,7 @@ def get_decoder_model(
    moe_dropout=0.0,
    turn_off_rop=False,  # turn off the RoP positional embedding
    version=1,
    position_embedding_type='learned_absolute'
):
    """Build language model and return along with the key to save."""

@@ -143,6 +144,7 @@
            num_moe_experts=num_moe_experts,
            moe_frequency=moe_frequency,
            moe_dropout=moe_dropout,
            position_embedding_type=position_embedding_type
        )
    elif arch == "retro":
        decoder = MegatronRetrievalTransformerDecoderModule(
@@ -89,6 +89,7 @@ def get_encoder_model(
    moe_dropout=0.0,
    turn_off_rop=False,  # turn off the RoP positional embedding
    version=1,  # model version
    position_embedding_type='learned_absolute'
):
    """Build language model and return along with the key to save."""

@@ -145,6 +146,7 @@
            num_moe_experts=num_moe_experts,
            moe_frequency=moe_frequency,
            moe_dropout=moe_dropout,
            position_embedding_type=position_embedding_type
        )
    elif arch == "retro":
        encoder = MegatronRetrievalTransformerEncoderModule(
@@ -84,6 +84,7 @@ def __init__(
        num_moe_experts=1,
        moe_frequency=1,
        moe_dropout=0.0,
        position_embedding_type='learned_absolute'
    ):
        super(MegatronTransformerDecoderModule, self).__init__()

@@ -147,6 +148,7 @@ def __init__(
            num_moe_experts=num_moe_experts,
            moe_frequency=moe_frequency,
            moe_dropout=moe_dropout,
            position_embedding_type=position_embedding_type
        )
        self._model_key = 'model'

@@ -81,6 +81,7 @@ def __init__(
        num_moe_experts=1,
        moe_frequency=1,
        moe_dropout=0.0,
        position_embedding_type='learned_absolute'
    ):
        super(MegatronTransformerEncoderModule, self).__init__()

@@ -145,6 +146,7 @@ def __init__(
            num_moe_experts=num_moe_experts,
            moe_frequency=moe_frequency,
            moe_dropout=moe_dropout,
            position_embedding_type=position_embedding_type
        )
        self._model_key = 'model'

@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
from torch.nn import functional as F


def sandwich_pos_bias(qlen, klen, hidden_size, num_attention_heads, device):
    context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None]
    memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (qlen, klen)

    inv_freq = 1.0 / (10000 ** (2 * torch.arange(1, hidden_size / 2, device=device) / hidden_size))

    _bias = torch.sum(relative_position[:, :, None].repeat(1, 1, len(inv_freq)) * inv_freq, axis=2)
    bias = _bias.repeat(num_attention_heads, 1, 1)

    _bias_scales = torch.arange(1, num_attention_heads + 1, 1, device=device)
    bias_scales = torch.stack(
        list(map(lambda x, y: x * y, _bias_scales, torch.ones(num_attention_heads, qlen, klen, device=device)))
    )
    scaled_bias = (bias - hidden_size / 2) / (bias_scales * 8 / num_attention_heads)

    return scaled_bias
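A quick usage sketch (not part of the diff) for sandwich_pos_bias above: the returned tensor has one bias matrix per attention head, shaped [num_attention_heads, qlen, klen], and is meant to be added to the raw attention scores like the other relative biases in this PR.

# Usage sketch (not part of the diff): shape check for the helper above.
import torch

bias = sandwich_pos_bias(qlen=8, klen=8, hidden_size=64, num_attention_heads=4, device=torch.device("cpu"))
assert bias.shape == (4, 8, 8)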
@@ -237,6 +237,7 @@ def __init__(
                num_moe_experts=encoder_cfg.get('num_moe_experts', 1),
                moe_frequency=encoder_cfg.get('moe_frequency', 1),
                moe_dropout=encoder_cfg.get('moe_dropout', 0.0),
                position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute')
            )

        if add_decoder:
@@ -365,6 +366,7 @@ def __init__(
                num_moe_experts=decoder_cfg.get('num_moe_experts', 1),
                moe_frequency=decoder_cfg.get('moe_frequency', 1),
                moe_dropout=decoder_cfg.get('moe_dropout', 0.0),
                position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute')
            )

        self.enc_dec_model = MegatronTransformerEncoderDecoderModule(