torchbench hf_* models fail on both TPU and GPU #6864

Closed
zpcore opened this issue Apr 1, 2024 · 5 comments · Fixed by #6869
zpcore commented Apr 1, 2024

🐛 Bug

Torchbench models hf_Albert, hf_Bart, hf_Bert, hf_Bert_large, hf_BigBird, hf_DistilBert, hf_GPT2, hf_GPT2_large, hf_Longformer, hf_Reformer, hf_T5, hf_T5_base, hf_T5_generate, and hf_T5_large all started failing recently with the errors below.

To Reproduce

Steps to reproduce the behavior:

cd /tmp/ && git clone https://github.com/pytorch/benchmark.git
cd /tmp/ && git clone https://github.com/pytorch/pytorch.git
cd /tmp/ && git clone https://github.com/pytorch/xla.git
cd benchmark && python install.py models hf_Bert
python xla/benchmarks/experiment_runner.py \
  --suite-name torchbench \
  --accelerator cuda \
  --xla PJRT \
  --test eval \
  --repeat 2 \
  --iterations-per-run 1 \
  --print-subprocess \
  --no-resume -k hf_Bert

Error log:

1. With the dynamo backend:
ERROR:__main__:ERROR in subprocess
INFO:__main__:Run with --model-config={"model_name": "hf_Bert_large"} --experiment-config={"accelerator": "cuda", "xla": "PJRT", "xla_flags": null, "dynamo": "openxla", "test": "eval"}
WARNING:__main__:Enabling fast F32 multiplication for PyTorch
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711990897.946354    2258 service.cc:145] XLA service 0x5612ea54bb40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1711990897.946448    2258 service.cc:153]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
I0000 00:00:1711990897.947022    2258 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1711990897.947102    2258 gpu_helpers.cc:107] XLA backend allocating 31724126208 bytes on device 0 for BFCAllocator.
I0000 00:00:1711990897.947133    2258 gpu_helpers.cc:147] XLA backend will use up to 10574708736 bytes on device 0 for CollectiveBFCAllocator.
INFO:benchmark_model:Running torch.compile with opts {'backend': 'openxla'}
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 390, in _fn
    return fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 939, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
    result = inner_convert(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 713, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 541, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1078, in transform_code_object
    transformations(instructions, code_options)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 503, in transform
    tracer.run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2202, in run
    super().run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 481, in wrapper
    return inner_fn(self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1265, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 697, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 349, in call_function
    return tx.inline_user_function_return(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2362, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2476, in inline_call_
    tracer.run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 481, in wrapper
    return inner_fn(self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1265, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 697, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 341, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 295, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2362, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2476, in inline_call_
    tracer.run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 481, in wrapper
    return inner_fn(self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1277, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 697, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 349, in call_function
    return tx.inline_user_function_return(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2362, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2476, in inline_call_
    tracer.run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 481, in wrapper
    return inner_fn(self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1265, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 697, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 341, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 295, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2362, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2476, in inline_call_
    tracer.run()
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1326, in LOAD_ATTR
    self._load_attr(inst)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1316, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 939, in call_function
    return handler(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 823, in builtin_dipatch
    rv = fn(tx, args, kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 743, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 1519, in call_getattr
    return obj.var_getattr(tx, name)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/base.py", line 230, in var_getattr
    value = self.const_getattr(tx, name)
  File "/root/.local/lib/python3.8/site-packages/torch/_dynamo/variables/constant.py", line 123, in const_getattr
    member = getattr(self.value, name)
torch._dynamo.exc.InternalTorchDynamoError: 'str' object has no attribute 'size'

from user code:
   File "/tmp/xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 1360, in forward
    outputs = self.bert(
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 961, in forward
    input_shape = input_ids.size()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

2. With the non-dynamo backend:
ERROR:__main__:ERROR in subprocess
INFO:__main__:Run with --model-config={"model_name": "hf_Bert_large"} --experiment-config={"accelerator": "cuda", "xla": "PJRT", "xla_flags": null, "dynamo": null, "test": "eval"}
WARNING:__main__:Enabling fast F32 multiplication for PyTorch
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711990872.886326    2060 service.cc:145] XLA service 0x55c4d7299290 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1711990872.886425    2060 service.cc:153]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
I0000 00:00:1711990872.886901    2060 se_gpu_pjrt_client.cc:853] Using BFC allocator.
I0000 00:00:1711990872.886984    2060 gpu_helpers.cc:107] XLA backend allocating 31724126208 bytes on device 0 for BFCAllocator.
I0000 00:00:1711990872.887012    2060 gpu_helpers.cc:147] XLA backend will use up to 10574708736 bytes on device 0 for CollectiveBFCAllocator.
Traceback (most recent call last):
  File "xla/benchmarks/experiment_runner.py", line 945, in <module>
    main()
  File "xla/benchmarks/experiment_runner.py", line 941, in main
    runner.run()
  File "xla/benchmarks/experiment_runner.py", line 61, in run
    self.run_single_config()
  File "xla/benchmarks/experiment_runner.py", line 256, in run_single_config
    metrics, last_output = self.run_once_and_gather_metrics(
  File "xla/benchmarks/experiment_runner.py", line 345, in run_once_and_gather_metrics
    output, _ = loop(iter_fn=self._default_iter_fn)
  File "xla/benchmarks/experiment_runner.py", line 302, in loop
    output, timing, trace = iter_fn(benchmark_experiment, benchmark_model,
  File "xla/benchmarks/experiment_runner.py", line 218, in _default_iter_fn
    output = benchmark_model.model_iter_fn(
  File "/tmp/xla/benchmarks/benchmark_model.py", line 170, in eval
    pred = self.module(*inputs)
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 1360, in forward
    outputs = self.bert(
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 960, in forward
    self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  File "/usr/local/lib/python3.8/site-packages/transformers/modeling_utils.py", line 4169, in warn_if_padding_and_no_attention_mask
    if self.config.pad_token_id in input_ids[:, [-1, 0]]:
TypeError: string indices must be integers

zpcore self-assigned this Apr 1, 2024
JackCaoG commented Apr 1, 2024

are we using nightly HF as well?

zpcore commented Apr 1, 2024

are we using nightly HF as well?

No, Torchbench is using transformers==4.38. We use the same config.

zpcore commented Apr 1, 2024

The failure is related to this PR: #6792. Here is the chat from @JackCaoG:

@Jiewen Tan I think we need to think about how to support flash_attention with dynamo. It seems like if the user puts your flash_attention wrapper (https://github.com/pytorch/xla/blob/master/torch_xla/experimental/custom_kernel.py#L147-L185) inside the torch.compile region, dynamo will try to step through the JAX Python code and eventually fail.

Based on your original PR https://github.com/pytorch/xla/pull/6477, do you expect the user to call flash_attention outside of torch.compile and extract the payload and use that in the torch.compile region?

Waiting for the pending fix from @alanwaketan.
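
For illustration, here is a hypothetical sketch of the pattern described in that chat: the flash_attention call is kept outside the torch.compile region so dynamo never traces its Pallas/JAX internals. The import path comes from the link above; the tensor shapes and exact call signature are assumptions, not a confirmed recipe:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention  # wrapper linked above

device = xm.xla_device()
# Illustrative shapes only: (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 2, 128, 128, device=device)
k = torch.randn(1, 2, 128, 128, device=device)
v = torch.randn(1, 2, 128, 128, device=device)

# Called eagerly, outside torch.compile: dynamo never steps into the wrapper.
attn_out = flash_attention(q, k, v)

@torch.compile(backend="openxla")
def model_tail(x):
    # Only this surrounding computation is captured by dynamo.
    return (x * 2).sum()

result = model_tail(attn_out)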

zpcore commented Apr 1, 2024

Sorry, I double-checked and the failure is actually due to changes in torchbench. Let me confirm which PR and make an update.

zpcore commented Apr 1, 2024

The issue is caused by an upstream torchbench change: pytorch/benchmark#2197.

In torchbenchmark/util/framework/huggingface/model_factory.py:

def get_module(self):
    return self.model, self.example_inputs

The returned self.example_inputs is now a dict instead of a list of tensors, so the runner's positional unpacking (pred = self.module(*inputs)) hands the dict's keys, which are strings, to the model.
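
To make the failure concrete, here is a minimal sketch; the call_module helper is hypothetical and is not the actual fix in #6869:

import torch

# New torchbench behavior: example_inputs is a dict keyed by argument name.
example_inputs = {"input_ids": torch.ones(1, 8, dtype=torch.long)}

def forward(input_ids):
    return input_ids.size()  # the same call that fails in modeling_bert.py

# Unpacking a dict with * yields its keys, so the model receives the string
# "input_ids" and fails exactly as in the logs above:
#   forward(*example_inputs)  # AttributeError: 'str' object has no attribute 'size'

def call_module(module_fn, inputs):
    # Hypothetical dict-tolerant call site, not the actual change in #6869.
    if isinstance(inputs, dict):
        return module_fn(**inputs)  # keyword inputs
    return module_fn(*inputs)       # legacy positional list/tuple of tensors

print(call_module(forward, example_inputs))  # torch.Size([1, 8])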
