
[BugFix] fix num_lookahead_slots missing in async executor #4165

Merged · 2 commits into vllm-project:main · Apr 30, 2024

Conversation

@leiwen83 (Contributor) commented Apr 18, 2024

This fixes the following stacktrace:

handle: <Handle functools.partial(<function _raise_exception_on_finish at 0x7fed348f8d30>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7fed2a489ed0>>)>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 38, in _raise_exception_on_finish
    task.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 496, in run_engine_loop
    has_requests_in_progress = await asyncio.wait_for(
  File "/usr/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 470, in engine_step
    request_outputs = await self.engine.step_async()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 213, in step_async
    output = await self.model_executor.execute_model_async(
  File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 152, in execute_model_async
    output = await make_async(self.driver_worker.execute_model)(
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: SpecDecodeWorker.execute_model() missing 1 required positional argument: 'num_lookahead_slots'
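
For context, here is a self-contained toy (not vLLM code; every name in it is illustrative) that reproduces the failure mode: an async wrapper built with run_in_executor drops a required argument and raises the same kind of TypeError, and forwarding the argument -- which is essentially what this PR does for the async executor path -- fixes it.

    import asyncio
    from functools import partial


    class ToySpecDecodeWorker:
        def execute_model(self, seq_group_metadata_list, num_lookahead_slots):
            return f"ran {len(seq_group_metadata_list)} groups with {num_lookahead_slots=}"


    def make_async(fn):
        """Run a blocking callable in a thread pool (loosely mirroring the
        make_async helper visible in the trace above)."""
        async def _async_wrapper(*args, **kwargs):
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, partial(fn, *args, **kwargs))
        return _async_wrapper


    async def broken_execute_model_async(worker, seq_group_metadata_list):
        # num_lookahead_slots is not forwarded, so the call fails at runtime.
        return await make_async(worker.execute_model)(seq_group_metadata_list)


    async def fixed_execute_model_async(worker, seq_group_metadata_list,
                                        num_lookahead_slots):
        # Forwarding the argument is the essence of the fix.
        return await make_async(worker.execute_model)(
            seq_group_metadata_list, num_lookahead_slots=num_lookahead_slots)


    async def main():
        worker = ToySpecDecodeWorker()
        try:
            await broken_execute_model_async(worker, [])
        except TypeError as exc:
            print("broken async path:", exc)
        print("fixed async path:", await fixed_execute_model_async(worker, [], 5))


    asyncio.run(main())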

@leiwen83 (Contributor, Author):

cc @cadedaniel

@rkooo567 (Collaborator) left a comment:

Good catch. @cadedaniel is there any relevant test that could be updated to test it?

@cadedaniel (Collaborator):

Sorry I missed this. The fix looks good. To get it merged:

  • Can we add a stacktrace with the failure exception to the PR description?
  • Can we add a test to catch this in the future?
    • A straightforward way is to add a test like the one below (from the spec decode e2e tests), but using the AsyncLLMEngine instead of the LLM entrypoint:

      def test_spec_decode_e2e_with_detokenization(test_llm_generator,
                                                   batch_size: int):
          """Run generation with speculative decoding on a batch. Verify the engine
          generates the correct number of tokens (via ignore_eos=True), and that the
          detokenization matches HF transformers.
          """
          output_len = 32
          temperature = 0.0

          prompts = [
              "Hello, my name is",
              "The president of the United States is",
              "The capital of France is",
              "The future of AI is",
          ]
          prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

          sampling_params = SamplingParams(
              max_tokens=output_len,
              ignore_eos=True,
              temperature=temperature,
          )

          batch_tokens, batch_token_ids = get_output_from_llm_generator(
              test_llm_generator, prompts, sampling_params)

          # Expect a generation for each prompt in the batch.
          assert len(batch_token_ids) == len(prompts)

          # Expect each generation to have the expected number of tokens (note
          # ignore_eos is True).
          assert [len(token_ids)
                  for token_ids in batch_token_ids] == ([output_len] * batch_size)

          # Expect the detokenized string to match.
          tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
          for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
              expected_tokens = tok.decode(actual_token_ids)
              print(f"{actual_token_ids=}")
              assert actual_tokens.strip() == expected_tokens.strip()

@leiwen83 (Contributor, Author) commented Apr 23, 2024

Hi @cadedaniel,

This issue was triggered while I was testing speculative inference with the OpenAI API server; the backtrace is the one now included in the PR description above.


For the test case, I'm wondering whether we could add a new interface in vllm/entrypoints/llm.py so that requests could be made async, or whether we need to implement something similar in the tests folder.

@cadedaniel (Collaborator):

For the test case, I'm wondering whether we could add a new interface in vllm/entrypoints/llm.py so that requests could be made async, or whether we need to implement something similar in the tests folder.

Exactly. In the ideal case we can run an async LLM entrypoint, but that might be a lot of work -- it's ok to have a hackier version in a conftest somewhere that allows us to get coverage of this codepath.
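
For illustration, a minimal sketch of what such a conftest helper might look like, assuming the AsyncLLMEngine / AsyncEngineArgs interface vLLM had around this PR (from_engine_args plus an async generate(prompt, sampling_params, request_id) generator); this is only the shape of the idea, not the helper that actually landed:

    import asyncio
    import uuid

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    class AsyncLLM:
        """LLM-like test helper that drives AsyncLLMEngine, so the
        execute_model_async codepath (and hence this bug) gets exercised."""

        def __init__(self, **engine_kwargs):
            self.engine = AsyncLLMEngine.from_engine_args(
                AsyncEngineArgs(**engine_kwargs))

        def generate(self, prompts, sampling_params: SamplingParams):
            async def _consume(prompt):
                final_output = None
                # generate() is an async generator; keep the last output,
                # which carries the full completion.
                async for output in self.engine.generate(
                        prompt, sampling_params, request_id=str(uuid.uuid4())):
                    final_output = output
                return final_output

            async def _run_all():
                return await asyncio.gather(*(_consume(p) for p in prompts))

            return asyncio.run(_run_all())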

@leiwen83 (Contributor, Author):

For the test case, I'm wondering whether we could add a new interface in vllm/entrypoints/llm.py so that requests could be made async, or whether we need to implement something similar in the tests folder.

Exactly. In the ideal case we can run an async LLM entrypoint, but that might be a lot of work -- it's ok to have a hackier version in a conftest somewhere that allows us to get coverage of this codepath.

I tried adding an AsyncEngine-mode LLM implementation in conftest, but found that it does not free the CUDA memory after being deleted, so it breaks the tests that run after it.
After a lot of testing, I found we may need to resort to a subprocess: run the async LLM inside it and destroy the subprocess itself after the test finishes. With this admittedly ugly setup, the error can now be caught in a test case, and the normal sync LLM cases still work.

@zxdvd (Contributor) commented Apr 24, 2024

@leiwen83 The CPU executor added async support recently; that needs to be fixed too.

@rkooo567 (Collaborator) left a comment:

Some comments regarding tests

"""


class AsyncLLM:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider supporting GPU cleanup using the cleanup method in conftest.py?

I think you can make it a conftest and delete the instance & call cleanup() in the destructor, similar to

    def __del__(self):
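
A hedged sketch of that suggestion (cleanup here stands in for the helper the spec decode e2e conftest already uses; its body below is an assumption about what such a helper does, not the repo's actual code):

    import contextlib
    import gc

    import torch


    def cleanup():
        # Stand-in for the conftest helper: drop Python references and ask
        # the CUDA allocator to return cached blocks.
        gc.collect()
        with contextlib.suppress(Exception):
            torch.cuda.empty_cache()


    class AsyncLLM:
        def __init__(self, **engine_kwargs):
            self.engine_kwargs = engine_kwargs  # engine construction elided

        def __del__(self):
            # Drop any engine reference, then force collection so later tests
            # in the same process can reclaim GPU memory.
            self.__dict__.pop("engine", None)
            cleanup()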


Diff context:

    with Manager() as manager:
        result = manager.dict()
        p = Process(target=_test_spec_decode_e2e_with_detokenization_async,

Review comment (Collaborator):

This is prone to leak. Consider:

    p = Process(target=_test_spec_decode_e2e_with_detokenization_async,
                args=(request, common_llm_kwargs,
                      per_test_common_llm_kwargs, test_llm_kwargs,
                      batch_size, seed, result))
    try:
        p.start()
        p.join()
    finally:
        p.terminate()

Alternatively, I recommend using ray (and calling ray.shutdown() in the finally block).

    import ray

    try:
        ray.init()

        @ray.remote
        def run():
            ...  # call your thing

        ray.get(run.remote())
    finally:
        ray.shutdown()

Diff context:

    print(f"{actual_token_ids=}")
    assert actual_tokens.strip() == expected_tokens.strip()

    del llm

Review comment (Collaborator):

Can you clean up as I suggested above ^?

@cadedaniel (Collaborator):

@rkooo567 's suggestions are good. Another way to do it is to look at spec decode e2e tests for inspiration:

    def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                             per_test_common_llm_kwargs, distinct_llm_kwargs,
                             seed):
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **distinct_llm_kwargs,
        }
        test_name = request.node.name

        def generator_inner():
            print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
            llm = LLM(**kwargs)
            set_random_seed(seed)
            yield llm
            del llm
            cleanup()

        def generator_outer():
            for llm in generator_inner():
                yield llm
                del llm

        return generator_outer

@leiwen83 (Contributor, Author):

Simply deleting the llm and calling cleanup would not help reduce GPU memory usage here. I think that is because async mode creates a background loop thread, which is not destroyed even when the llm object itself is deleted.

I will try the ray method.
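
A hedged sketch of that direction (the helper name below is hypothetical, and this is not the code that landed): running the engine inside a ray task means the GPU memory lives in a separate ray worker process, so ray.shutdown() reclaims it even though the engine's background loop never exits cleanly.

    import ray


    def run_in_isolated_ray_worker(test_fn, *args, **kwargs):
        """Run test_fn inside a ray worker that owns the GPU, then tear the
        worker down so its CUDA memory cannot leak into later tests."""
        try:
            ray.init()

            @ray.remote(num_gpus=1)
            def _worker():
                # test_fn builds the async engine, runs generation, and returns
                # plain Python data (tokens / token ids) for assertions.
                return test_fn(*args, **kwargs)

            return ray.get(_worker.remote())
        finally:
            ray.shutdown()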

@leiwen83 (Contributor, Author):

@leiwen83 The CPU executor added async support recently; that needs to be fixed too.

Thanks for the reminder; the CPU executor is also fixed now.

@leiwen83 (Contributor, Author):

@rkooo567 The test case is now reworked to use the ray method.

@cadedaniel The async mode in the e2e case is now compatible with the original parameter-passing approach.

@rkooo567 (Collaborator) left a comment:

I think the current one is okay, but I have the impression we can just reuse engine_use_ray (let me know if this is wrong!).

Two review threads on tests/spec_decode/e2e/conftest.py were resolved (one outdated).
@cadedaniel (Collaborator) left a comment:

Code looks good -- thanks!

what's up with the CI failure?

@leiwen83 (Contributor, Author) commented Apr 30, 2024

Code looks good -- thanks!

what's up with the CI failure?

No idea... I rebased the code, which retriggered the CI, and it still seems to have a similar issue: it looks like the ray env cannot allocate the required {'CPU': 1.0, 'GPU': 0.9} resources for

    spec_decode/e2e/test_correctness.py::test_spec_decode_e2e_with_detokenization[1-1-test_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0] Creating baseline_or_test='test' LLM for test_name='test_spec_decode_e2e_with_detokenization[1-1-test_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]'. kwargs={'model': 'JackFram/llama-68m', 'enforce_eager': True, 'use_v2_block_manager': True, 'use_async': True, 'speculative_model': 'JackFram/llama-68m', 'num_speculative_tokens': 5}

@leiwen83 (Contributor, Author):

@cadedaniel I found the root cause of the CI test failure. test_spec_decode_xfail_ray in tests/spec_decode/e2e/test_compatibility.py initializes a ray cluster that claims the GPU resource but never releases it, so the later ray calls cannot acquire the GPU, since there is only one GPU in the environment.

@cadedaniel (Collaborator) commented Apr 30, 2024

OK, let's skip that test. We can follow up and fix it later.

    @pytest.mark.skip("Ray does not release GPU resources in this test")

@cadedaniel (Collaborator):

oh that works too. thanks!

@cadedaniel merged commit 4bb53e2 into vllm-project:main on Apr 30, 2024 · 48 checks passed
@andysalerno commented Apr 30, 2024

This may have introduced a failure in Dockerfile.cpu scenarios, cc @cadedaniel. I suppose the CPU path simply needs to be updated to accept the new parameter?

python3 -m vllm.entrypoints.openai.api_server --model microsoft/Phi-3-mini-128k-instruct --trust-remote-code --max-model-len 8000
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 411, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 69, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/cors.py", line 93, in __call__
    await self.simple_response(scope, receive, send, request_headers=headers)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/cors.py", line 148, in simple_response
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 72, in app
    response = await func(request)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 278, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
  File "/workspace/vllm/vllm/entrypoints/openai/api_server.py", line 90, in create_chat_completion
    generator = await openai_serving_chat.create_chat_completion(
  File "/workspace/vllm/vllm/entrypoints/openai/serving_chat.py", line 128, in create_chat_completion
    return await self.chat_completion_full_generator(
  File "/workspace/vllm/vllm/entrypoints/openai/serving_chat.py", line 290, in chat_completion_full_generator
    async for res in result_generator:
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 663, in generate
    raise e
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 657, in generate
    async for request_output in stream:
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 77, in __anext__
    raise result
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 38, in _raise_exception_on_finish
    task.result()
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 498, in run_engine_loop
    has_requests_in_progress = await asyncio.wait_for(
  File "/usr/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 472, in engine_step
    request_outputs = await self.engine.step_async()
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 213, in step_async
    output = await self.model_executor.execute_model_async(
  File "/workspace/vllm/vllm/executor/cpu_executor.py", line 114, in execute_model_async
    output = await make_async(self.driver_worker.execute_model)(
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: CPUWorker.execute_model() got an unexpected keyword argument 'num_lookahead_slots'
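
For context, a hedged sketch of the kind of follow-up this trace points at (parameter names other than num_lookahead_slots are assumed, and this is not the actual vLLM patch): CPUWorker.execute_model needs to accept the num_lookahead_slots keyword that the async CPU executor now forwards, even if the CPU backend has no use for it yet.

    # Illustrative signature only -- not the real CPUWorker.
    class CPUWorker:
        def execute_model(
            self,
            seq_group_metadata_list=None,
            blocks_to_swap_in=None,
            blocks_to_swap_out=None,
            blocks_to_copy=None,
            num_lookahead_slots: int = 0,  # accepted for interface parity; unused on CPU
        ):
            # ... run the model as before; num_lookahead_slots is simply ignored
            # until the CPU backend supports speculative decoding.
            ...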

@cadedaniel (Collaborator):

@andysalerno do you have a repro script?

@cadedaniel (Collaborator):

Err, why is this not showing up in CI?

@leiwen83 leiwen83 deleted the num_lookahead_slots_fix branch May 2, 2024 08:58
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024