Skip to content

Commit

Permalink
Merge branch 'master' into feature/upgrade_to_cu118
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim authored Jul 22, 2023
2 parents 63fe99e + 48fe78f commit c8da525
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/performance_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ You can find more information on TorchServe benchmarking [here](https://github.c

TorchServe has native support for the PyTorch profiler which will help you find performance bottlenecks in your code.

If you created a custom `handle` or `initialize` method overwriting the BaseHandler, you must define the `self.manifest` attribute to be able to run `_infer_with_profiler`.

```
export ENABLE_TORCH_PROFILER=TRUE
```
Expand Down
3 changes: 3 additions & 0 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def handle(self, data, context):
is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None)
if is_profiler_enabled:
if PROFILER_AVAILABLE:
if self.manifest is None:
# profiler will use to get the model name
self.manifest = context.manifest
output, _ = self._infer_with_profiler(data=data)
else:
raise RuntimeError(
Expand Down
13 changes: 12 additions & 1 deletion ts/torch_handler/unit_tests/test_base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
Ensures it can load and execute an example model
"""

import os
import pytest

from ts.torch_handler.base_handler import BaseHandler
from ts.torch_handler.base_handler import BaseHandler, PROFILER_AVAILABLE


@pytest.fixture()
Expand All @@ -30,3 +31,13 @@ def test_batch_handle(handler, base_model_context):
processed = handler.handle(list_data, base_model_context)

assert processed == [1, 0]


def test_inference_with_profiler_works_with_custom_initialize_method(handler, base_model_context):
handler.manifest = None
PROFILER_AVAILABLE = True
os.environ["ENABLE_TORCH_PROFILER"] = "1"

list_data = [[1.0, 2.0], [4.0, 3.0]]
processed = handler.handle(list_data, base_model_context)
assert processed == [1, 0]

0 comments on commit c8da525

Please sign in to comment.