From d85949a7eec9604793aff83b983aab3cda2d86a3 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Mon, 24 Jun 2024 08:51:39 -0400
Subject: [PATCH 1/7] Added MLPSpeculator e2e tests

Signed-off-by: Thomas Parnell
---
 tests/spec_decode/e2e/test_mlp_correctness.py | 206 ++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 tests/spec_decode/e2e/test_mlp_correctness.py

diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
new file mode 100644
index 0000000000000..8a1b1287a2a92
--- /dev/null
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -0,0 +1,206 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
+and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
+Since there is no model is needed for generate the proposal, we could make
+the testcase much simpler than drafter multi-step one.
+
+However, we still need to verify that the scenarios below pass:
+    * Batch size 1 greedy equality
+    * Batch size >1 greedy equality
+    * Test greedy equality under preemption
+    * Test greedy equality under various ngram sizes / speculative sizes
+
+With those tests, we can say at least, ngram spec would not break the correctess
+for the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+MAIN_MODEL="meta-llama/Llama-2-13b-chat-hf"
+SPEC_MODEL="ibm-fms/llama-13b-accelerator"
+MAX_SPEC_TOKENS=3
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        "dtype": "float16",
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": MAIN_MODEL,
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    256,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
+                                      test_llm_generator, batch_size: int,
+                                      output_len: int):
+    """Verify greedy equality on a tiny model with different batch size."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "model": MAIN_MODEL,
+    },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        256,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                      test_llm_generator,
+                                                      batch_size: int,
+                                                      output_len: int):
+    """Verify greedy equality, even when some sequences are preempted mid-
+    generation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": SPEC_MODEL,
+            "num_speculative_tokens": k,
+        }
+        # Try a range of common k, as well as large speculation.
+        for k in range(1, 1+MAX_SPEC_TOKENS)
+    ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
+                           batch_size: int, output_len: int):
+    """Verify that ngram speculative decoding produces exact equality
+    to without spec decode with many different values of k and
+    different ngram_prompt_lookup_max.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": MAIN_MODEL,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": SPEC_MODEL,
+                             "num_speculative_tokens": MAX_SPEC_TOKENS,
+                             "speculative_disable_by_batch_size": 4
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
+                             batch_size: int, output_len: int):
+    """Verify that ngram speculative decoding produces exact equality
+    to without spec decode with many different values of k and
+    different ngram_prompt_lookup_max.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)

From da642116aa682a82044e576bf236d20109a69e8e Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Mon, 24 Jun 2024 12:12:43 -0400
Subject: [PATCH 2/7] Rework docstrings

Signed-off-by: Thomas Parnell
---
 tests/spec_decode/e2e/test_mlp_correctness.py | 82 +++++++++++--------
 1 file changed, 47 insertions(+), 35 deletions(-)

diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index 8a1b1287a2a92..be8e969695622 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -9,28 +9,32 @@ numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality.
 
-For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
-and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
-Since there is no model is needed for generate the proposal, we could make
-the testcase much simpler than drafter multi-step one.
-
 However, we still need to verify that the scenarios below pass:
     * Batch size 1 greedy equality
     * Batch size >1 greedy equality
     * Test greedy equality under preemption
-    * Test greedy equality under various ngram sizes / speculative sizes
+    * Test greedy equality under various numbers of speculative tokens.
 
-With those tests, we can say at least, ngram spec would not break the correctess
-for the target model outputs.
+With those tests, we can say at least that MLPSpeculator would not break the
+correctness of the target model outputs.
 """
 
 import pytest
 
 from .conftest import run_greedy_equality_correctness_test
 
-MAIN_MODEL="meta-llama/Llama-2-13b-chat-hf"
-SPEC_MODEL="ibm-fms/llama-13b-accelerator"
-MAX_SPEC_TOKENS=3
+# main model
+MAIN_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+# speculative model
+SPEC_MODEL = "ibm-fms/llama3-8b-accelerator"
+
+# max. number of speculative tokens
+MAX_SPEC_TOKENS = 4
+
+# precision
+PRECISION = "float32"
+
 
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -44,7 +48,8 @@
         # Print spec metrics.
         "disable_log_stats": False,
 
-        "dtype": "float16",
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
         "model": MAIN_MODEL,
     },
 ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize("output_len", [
     256,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
 @pytest.mark.parametrize("seed", [1])
-def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
-                                      test_llm_generator, batch_size: int,
-                                      output_len: int):
-    """Verify greedy equality on a tiny model with different batch size."""
+def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
+                                    batch_size: int, output_len: int):
+    """Verify greedy equality with different batch sizes."""
     run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
@@ -86,7 +90,10 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
         "enforce_eager": True,
 
         # Required for spec decode.
-        "use_v2_block_manager": True
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
         "model": MAIN_MODEL,
     },
 ])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
         "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize(
     "output_len",
     [
         # Use small output len for fast test.
         256,
     ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
-def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
-                                                      test_llm_generator,
-                                                      batch_size: int,
-                                                      output_len: int):
+def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                    test_llm_generator,
+                                                    batch_size: int,
+                                                    output_len: int):
     """Verify greedy equality, even when some sequences are preempted mid-
     generation.
     """
     run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
@@ -131,7 +138,10 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
         "enforce_eager": True,
 
         # Required for spec decode.
-        "use_v2_block_manager": True
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -142,8 +152,8 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
             "speculative_model": SPEC_MODEL,
             "num_speculative_tokens": k,
         }
-        # Try a range of common k, as well as large speculation.
-        for k in range(1, 1+MAX_SPEC_TOKENS)
+        # Try a range of num. speculative tokens
+        for k in range(1, 1 + MAX_SPEC_TOKENS)
     ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
@@ -153,11 +163,10 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
         32,
     ])
 @pytest.mark.parametrize("seed", [1])
-def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
-                           batch_size: int, output_len: int):
-    """Verify that ngram speculative decoding produces exact equality
-    to without spec decode with many different values of k and
-    different ngram_prompt_lookup_max.
+def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
+                         batch_size: int, output_len: int):
+    """Verify that mlp speculative decoding produces exactly the same output
+    as normal decoding for different values of num_speculative_tokens.
     """
     run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
@@ -175,7 +184,10 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
         "enforce_eager": True,
 
         # Required for spec decode.
-        "use_v2_block_manager": True
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -193,11 +205,11 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
         32,
     ])
 @pytest.mark.parametrize("seed", [1])
-def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
-                             batch_size: int, output_len: int):
-    """Verify that ngram speculative decoding produces exact equality
-    to without spec decode with many different values of k and
-    different ngram_prompt_lookup_max.
+def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
+                           batch_size: int, output_len: int):
+    """Verify that mlp speculative decoding produces exactly the same output
+    as normal decoding when speculation is disabled for large
+    batch sizes.
""" run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, From 15012a9187f3a4b88bc5895913919fea373e8632 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 24 Jun 2024 13:54:51 -0400 Subject: [PATCH 3/7] Switch to granite-3b Signed-off-by: Thomas Parnell --- tests/spec_decode/e2e/test_mlp_correctness.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index be8e969695622..2c232deb587ff 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -24,16 +24,16 @@ from .conftest import run_greedy_equality_correctness_test # main model -MAIN_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct" +MAIN_MODEL = "ibm-granite/granite-3b-code-instruct" # speculative model -SPEC_MODEL = "ibm-fms/llama3-8b-accelerator" +SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator" # max. number of speculative tokens -MAX_SPEC_TOKENS = 4 +MAX_SPEC_TOKENS = 5 # precision -PRECISION = "float32" +PRECISION = "float16" @pytest.mark.parametrize( From 34a1b8c97efb3f36626cabb6f3af5a54a2430041 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 24 Jun 2024 21:53:11 +0000 Subject: [PATCH 4/7] Use float32 Signed-off-by: Thomas Parnell --- tests/spec_decode/e2e/test_mlp_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2c232deb587ff..d49a001066995 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -33,7 +33,7 @@ MAX_SPEC_TOKENS = 5 # precision -PRECISION = "float16" +PRECISION = "float32" @pytest.mark.parametrize( From ee8ce5b619470a62d09ad69cdf8b0fe26ca6c406 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 24 Jun 2024 22:40:51 +0000 Subject: [PATCH 5/7] Use bfloat16 (float32 leading to OOM) Signed-off-by: Thomas Parnell --- tests/spec_decode/e2e/test_mlp_correctness.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index d49a001066995..f7fae59062d83 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -33,7 +33,7 @@ MAX_SPEC_TOKENS = 5 # precision -PRECISION = "float32" +PRECISION = "bfloat16" @pytest.mark.parametrize( From a96d11543a83f670df4a72ff0655c2dfe4dfaa02 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Tue, 25 Jun 2024 06:54:07 -0400 Subject: [PATCH 6/7] Adjust test to avoid precision issue in float16 Signed-off-by: Thomas Parnell --- tests/spec_decode/e2e/test_mlp_correctness.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index f7fae59062d83..eeb93026f2f1f 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -33,7 +33,7 @@ MAX_SPEC_TOKENS = 5 # precision -PRECISION = "bfloat16" +PRECISION = "float16" @pytest.mark.parametrize( @@ -64,7 +64,7 @@ }, ]) @pytest.mark.parametrize("output_len", [ - 256, + 128, ]) @pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) @@ -111,7 +111,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, "output_len", [ # Use small output len for fast test. 
-        256,
+        128,
     ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])

From ad42323e24397f8af39bcb30f8bfddf9a8ac2d45 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Tue, 25 Jun 2024 08:21:54 -0400
Subject: [PATCH 7/7] Addressed comments

Signed-off-by: Thomas Parnell
---
 tests/spec_decode/e2e/test_mlp_correctness.py | 34 +++++++++----------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py
index eeb93026f2f1f..9a9f2acbb8f39 100644
--- a/tests/spec_decode/e2e/test_mlp_correctness.py
+++ b/tests/spec_decode/e2e/test_mlp_correctness.py
@@ -29,7 +29,8 @@
 # speculative model
 SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
 
-# max. number of speculative tokens
+# max. number of speculative tokens: this corresponds to
+# n_predict in the config.json of the speculator model.
 MAX_SPEC_TOKENS = 5
 
 # precision
@@ -50,17 +51,15 @@
 
         # Precision
         "dtype": PRECISION,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
+
+        # Main model
         "model": MAIN_MODEL,
-    },
-])
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -94,17 +93,15 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
 
         # Precision
         "dtype": PRECISION,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
+
+        # Main model
         "model": MAIN_MODEL,
-    },
-])
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize(
@@ -132,8 +129,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": MAIN_MODEL,
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
@@ -142,6 +137,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
 
         # Precision
         "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -178,8 +176,6 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": MAIN_MODEL,
-
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
 
@@ -188,13 +184,15 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
 
         # Precision
         "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs",
                          [{
                              "speculative_model": SPEC_MODEL,
-                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                              "speculative_disable_by_batch_size": 4
                          }])
 @pytest.mark.parametrize("batch_size", [1, 5])
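
A note on the shared harness: every test in this series delegates to run_greedy_equality_correctness_test, which is imported from .conftest and is not shown in these patches. The sketch below illustrates the shape of the check it performs; the function name, signature, and injected callables are illustrative assumptions for exposition, not the actual conftest implementation.

from typing import Callable, List

def check_greedy_equality(baseline_generate: Callable[[str, int], List[int]],
                          spec_generate: Callable[[str, int], List[int]],
                          prompts: List[str],
                          max_output_len: int) -> None:
    """Require token-for-token equality between normal and speculative decoding."""
    for prompt in prompts:
        # Both runs decode greedily (temperature=0), so the rejection-sampling
        # guarantee cited in the module docstring implies exact agreement.
        baseline_ids = baseline_generate(prompt, max_output_len)
        spec_ids = spec_generate(prompt, max_output_len)
        assert spec_ids == baseline_ids, (
            f"spec decode diverged for prompt {prompt!r}: "
            f"{spec_ids} != {baseline_ids}")

The preemption test's cache sizing is also worth spelling out: with block_size = 8, the override num_gpu_blocks_override = 2 + 256 // 8 reserves 2 blocks for the short prompt plus 32 blocks for generated tokens, i.e. 34 blocks in total, and max_model_len = (2 + 256 // 8) * 8 = 272. The KV cache is kept deliberately small so that a batch of 4 sequences cannot all stay resident at once, forcing some sequences to be preempted mid-generation, which is exactly what the test exercises.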