From 17fad4caf279a009fc063655499320682b7e67cc Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 14 Sep 2024 09:46:04 -0700
Subject: [PATCH] [torch.compile] fix functionalization (#8480)

---
 tests/compile/test_full_graph.py |  13 ++-
 vllm/compilation/backends.py     | 156 +++++++++++++++++++++++++++++++
 vllm/worker/model_runner.py      |   3 +-
 3 files changed, 167 insertions(+), 5 deletions(-)
 create mode 100644 vllm/compilation/backends.py

diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py
index 0a6e781e18834..43905082b7caf 100644
--- a/tests/compile/test_full_graph.py
+++ b/tests/compile/test_full_graph.py
@@ -16,7 +16,12 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
-              enforce_eager=True,
-              load_format="dummy")
-    llm.generate(prompts, sampling_params)
+    llm = LLM(model=model, enforce_eager=True)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
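
The new vllm/compilation/backends.py below works by direct FX graph surgery.
As a self-contained toy sketch (not part of the patch) of the three primitives
it leans on, inserting_before, replace_all_uses_with, and erase_node, applied
here to a plain torch.add rather than a vLLM custom op; the function f and the
torch.sub swap are purely illustrative:

import torch
import torch.fx as fx

def f(x):
    y = torch.add(x, 1)
    return torch.mul(y, 2)

gm = fx.symbolic_trace(f)

# Swap every torch.add for torch.sub, deferring erasure until after the
# loop, just as fix_functionalization defers via nodes_to_remove.
to_erase = []
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target == torch.add:
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(torch.sub, args=node.args)
        node.replace_all_uses_with(new_node)
        to_erase.append(node)
for node in to_erase:
    gm.graph.erase_node(node)

gm.recompile()
print(gm(torch.ones(2)))  # tensor([0., 0.]) since (1 - 1) * 2 == 0
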
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
new file mode 100644
index 0000000000000..de0b1d8a75757
--- /dev/null
+++ b/vllm/compilation/backends.py
@@ -0,0 +1,156 @@
+import operator
+
+import torch
+import torch.fx as fx
+
+
+def fix_functionalization(graph: fx.Graph):
+    """
+    Rewrite the graph module to replace the pattern involving
+    torch._higher_order_ops.auto_functionalize.auto_functionalized
+    with a direct call to the inplace custom op.
+
+    # TODO: check if PyTorch nightly has fixed this issue
+    """
+
+    # debug code, if we want to see the graph before the transformation
+    # with open("before.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+
+    nodes_to_remove = []
+
+    for node in graph.nodes:
+        # Identify the auto_functionalized node
+        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
+            if node.args[0] == torch.ops._C.rotary_embedding.default:
+                # manual replace for rotary_embedding
+
+                # Now, collect the arguments
+                kwargs = node.kwargs
+
+                query = kwargs['query']
+                mm_node = query.args[0].args[0]
+
+                # Create a new call to torch.ops._C.rotary_embedding.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(torch.ops._C.rotary_embedding.default,
+                                        kwargs=kwargs)
+
+                # Remove the auto_functionalized node
+                # Since the node may have outputs, we need to handle its users
+                # Replace uses of the outputs (getitem nodes) with mm_node
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        # Remove the getitem node
+                        for getitem_user in list(user.users):
+                            if (getitem_user.op == 'call_function'
+                                    and getitem_user.target
+                                    == torch.ops.aten.slice_scatter.default):
+                                # Replace the uses of slice_scatter node
+                                # with mm_node
+                                getitem_user.replace_all_uses_with(mm_node)
+                                nodes_to_remove.append(getitem_user)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
+                # manual replace for fused_add_rms_norm
+                # this is the most effective optimization for llama
+                # failing to do this will result in many unnecessary copies
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                residual = kwargs['residual']
+
+                # Create a new call to torch.ops._C.fused_add_rms_norm.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        # Remove the getitem node
+                        if user.args[1] == 1:
+                            replace_node = input
+                        elif user.args[1] == 2:
+                            replace_node = residual
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.rms_norm.default:
+                # manual replace for rms_norm
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                out = kwargs['out']
+                weight = kwargs['weight']
+                epsilon = kwargs['epsilon']
+                # Create a new call to torch.ops._C.rms_norm.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.rms_norm.default,
+                        args=(out, input, weight, epsilon),
+                    )
+
+                replace_node = out
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.silu_and_mul.default:
+                # manual replace for silu_and_mul
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                out = kwargs['out']
+
+                # Create a new call to torch.ops._C.silu_and_mul.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.silu_and_mul.default,
+                        args=(out, input),
+                    )
+                replace_node = out
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+    # Remove the nodes all at once
+    for node in nodes_to_remove:
+        graph.erase_node(node)
+
+    # debug code, if we want to see the graph after the transformation
+    # with open("after.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+
+
+def vllm_backend(graph, example_inputs):
+    from torch._inductor import config
+    current_config = config.shallow_copy_dict()
+    from torch._inductor.compile_fx import compile_fx
+    current_config['post_grad_custom_post_pass'] = fix_functionalization
+    return compile_fx(graph, example_inputs, config_patches=current_config)
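
A hedged usage sketch (not part of the patch) of how this backend is plugged
in, mirroring the model_runner.py change below. It assumes a vLLM install with
its compiled _C ops; the mlp module here is an arbitrary stand-in, and for
graphs that never hit the _C custom ops, fix_functionalization matches nothing
and the backend behaves like plain Inductor:

import torch
from vllm.compilation.backends import vllm_backend

# Any module works; the pass only rewrites auto_functionalized nodes that
# wrap vLLM's _C custom ops, and leaves other graphs untouched.
mlp = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(mlp, fullgraph=True, backend=vllm_backend)
print(compiled(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
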
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index a49690541c68c..4efcdf78f4039 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1064,8 +1064,9 @@ def load_model(self) -> None:
                     "This may lead to less accurate results!")
 
         if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
+            from vllm.compilation.backends import vllm_backend
             from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = get_torch_compile_backend() or vllm_backend
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
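
To see the Inductor hook in isolation, a minimal standalone sketch of the same
post_grad_custom_post_pass mechanism vllm_backend uses; log_post_grad_ops and
logging_backend are made-up names, and the pass only logs post-grad ops
instead of rewriting them:

import torch
import torch.fx as fx
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx

def log_post_grad_ops(graph: fx.Graph):
    # Runs after Inductor's post-grad passes, on the final FX graph.
    for node in graph.nodes:
        if node.op == 'call_function':
            print('post-grad op:', node.target)

def logging_backend(graph: fx.GraphModule, example_inputs):
    current_config = config.shallow_copy_dict()
    current_config['post_grad_custom_post_pass'] = log_post_grad_ops
    return compile_fx(graph, example_inputs, config_patches=current_config)

fn = torch.compile(lambda x: torch.relu(x) + 1, backend=logging_backend)
fn(torch.randn(4))  # triggers compilation and prints the post-grad ops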