
gha: test multiple versions #30

Merged: 11 commits, Jul 3, 2024
28 changes: 22 additions & 6 deletions .github/workflows/tests.yaml
@@ -11,6 +11,11 @@ on:

 env:
   FORCE_COLOR: "1"
+  # facilitate testing by building vLLM for CPU when needed
+  VLLM_CPU_DISABLE_AVX512: "true"
+  VLLM_TARGET_DEVICE: "cpu"
+  # prefer torch cpu version
+  PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu"

 concurrency:
   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
@@ -25,6 +30,11 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         pyv: ["3.11"]
+        vllm_version:
+          # - "pypi" # skip the pypi version as it will not work with CPU
+          - "git+https://github.com/vllm-project/vllm@v0.5.0.post1"
+          - "git+https://github.com/vllm-project/vllm@main"
+          - "git+https://github.com/opendatahub-io/vllm@main"

     steps:
       - name: Check out the repository
@@ -50,6 +60,14 @@ jobs:
           pip --version
           nox --version

+      - name: Set custom vllm version
+        if: ${{ matrix.vllm_version != 'pypi' }}
+        run: |
+          vllm_version="vllm@${{matrix.vllm_version}}"
+          echo "Using ${vllm_version}"
+
+          sed -i "s|\"vllm.*\",|\"${vllm_version}\",|g" pyproject.toml
+
       - name: Lint code and check dependencies
         run: nox -s lint-${{ matrix.pyv }}
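
The `sed` call in the "Set custom vllm version" step above rewrites whatever `vllm` requirement is pinned in `pyproject.toml`, so that installing the adapter resolves against the matrix-selected source build. A rough Python sketch of the same substitution; the pinned entry `"vllm>=0.4.2",` is a hypothetical stand-in, not the adapter's actual pin:

```python
import re

# Hypothetical dependency table as it might appear in pyproject.toml.
pyproject = """dependencies = [
    "vllm>=0.4.2",
]
"""

# Matrix-selected source build, prefixed into the PEP 508 "name @ url" form.
vllm_version = "vllm@git+https://github.com/vllm-project/vllm@main"

# Equivalent of: sed -i "s|\"vllm.*\",|\"${vllm_version}\",|g" pyproject.toml
patched = re.sub(r'"vllm.*",', f'"{vllm_version}",', pyproject)
print(patched)  # the pin is now a direct git reference
```

The `vllm@` prefix matters because pip only accepts direct URL requirements in the `name @ url` form.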

@@ -58,12 +76,10 @@ jobs:

       - name: Upload coverage report
         uses: codecov/codecov-action@v4
-        with:
-          fail_ci_if_error: true
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

       - name: Build package
         run: nox -s build-${{ matrix.pyv }}

       - name: Upload artifact
         uses: actions/upload-artifact@v4
         with:
           name: vllm-tgis-wheel
           path: dist/vllm_tgis_adapter*.whl
14 changes: 14 additions & 0 deletions README.md
@@ -78,3 +78,17 @@ nox -s tests-3.10 # run tests session for a specific python version
 nox -s build-3.11 # build the wheel package
 nox -s lint-3.11 -- --mypy # run linting with type checks
 ```
+
+### Testing without a GPU
+
+The standard vllm build requires an Nvidia GPU. When one is not available, `vllm` can be compiled from source with CPU support:
+
+```bash
+env \
+  VLLM_CPU_DISABLE_AVX512=true VLLM_TARGET_DEVICE=cpu \
+  PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu \
+  pip install git+https://github.com/vllm-project/vllm
+```
+
+This makes it possible to run the tests on most hardware. Note that the extra pip index URL is required to install the CPU build of torch.
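
A quick way to verify that a CPU-only build installed correctly is a sanity check along these lines (a sketch; exact version strings will vary):

```python
# Both imports should succeed without CUDA libraries present,
# and torch should report that no GPU is visible.
import torch
import vllm

print("vllm version:", vllm.__version__)
print("CUDA available:", torch.cuda.is_available())  # expected: False
```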
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -170,7 +170,8 @@ select = ["ALL"]
 ]
 "src/vllm_tgis_adapter/healthcheck.py" = ["T201"]
 "src/vllm_tgis_adapter/_version.py" = ["ALL"]
-"tests/**" = ["S", "ARG001", "ARG002", "ANN"]
+"tests/**" = ["S", "ARG001", "ARG002", "ANN", "PT019", "FBT001", "FBT002"]
+"tests/utils.py" = ["T201"]
 "setup.py" = [
     "T201", # print() use
     "S603", # subprocess call: check for execution of untrusted input
35 changes: 18 additions & 17 deletions src/vllm_tgis_adapter/__main__.py
@@ -56,21 +56,6 @@
 _running_tasks: set[asyncio.Task] = set()


-@asynccontextmanager
-async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator:  # noqa: ARG001
-    async def _force_log():  # noqa: ANN202
-        while True:
-            await asyncio.sleep(10)
-            await engine.do_log_stats()
-
-    if not engine_args.disable_log_stats:
-        task = asyncio.create_task(_force_log())
-        _running_tasks.add(task)
-        task.add_done_callback(_running_tasks.remove)
-
-    yield
-
-
 router = fastapi.APIRouter()

 # Add prometheus asgi middleware to route /metrics requests
@@ -131,7 +116,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: fastapi.Request):
     return JSONResponse(content=generator.model_dump())


-def build_app(args: argparse.Namespace) -> fastapi.FastAPI:
+def build_app(  # noqa: C901 # FIXME: waiting on https://github.com/vllm-project/vllm/pull/5090 to get rid of this
+    engine: AsyncLLMEngine, args: argparse.Namespace
+) -> fastapi.FastAPI:
+    @asynccontextmanager
+    async def lifespan(app: fastapi.FastAPI) -> AsyncGenerator:  # noqa: ARG001
+        async def _force_log():  # noqa: ANN202
+            while True:
+                await asyncio.sleep(10)
+                await engine.do_log_stats()
+
+        if not args.disable_log_stats:
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+
+        yield
+
     app = fastapi.FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
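
Moving `lifespan` inside `build_app` lets it close over the `engine` and `args` arguments instead of reaching for module-level globals. A minimal standalone sketch of the pattern, with illustrative names that are not the adapter's own:

```python
import asyncio
from contextlib import asynccontextmanager

import fastapi


def make_app(interval: float = 10.0) -> fastapi.FastAPI:
    running_tasks: set[asyncio.Task] = set()

    @asynccontextmanager
    async def lifespan(app: fastapi.FastAPI):
        async def heartbeat() -> None:
            while True:
                await asyncio.sleep(interval)
                print("log stats")  # stand-in for engine.do_log_stats()

        # Hold a reference so the task isn't garbage-collected early.
        task = asyncio.create_task(heartbeat())
        running_tasks.add(task)
        task.add_done_callback(running_tasks.discard)
        yield  # the application serves requests while suspended here

    return fastapi.FastAPI(lifespan=lifespan)
```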
@@ -182,7 +183,7 @@ async def run_http_server(
     args: argparse.Namespace,
     model_config: ModelConfig,
 ) -> None:
-    app = build_app(args)
+    app = build_app(engine, args)

     if args.served_model_name is not None:
         served_model_names = args.served_model_name
43 changes: 23 additions & 20 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -191,16 +191,22 @@ async def post_init(self) -> None:
         assert self.tokenizer is not None

         # Swap in the special TGIS stats logger
-        assert hasattr(self.engine.engine, "stat_logger")
-        assert self.engine.engine.stat_logger
-
-        vllm_stat_logger = self.engine.engine.stat_logger
-        tgis_stats_logger = TGISStatLogger(
-            vllm_stat_logger=vllm_stat_logger,
-            max_sequence_len=self.config.max_model_len,
-        )
-        # 🌶️🌶️🌶️ sneaky sneak
-        self.engine.engine.stat_logger = tgis_stats_logger
+        if hasattr(self.engine.engine, "stat_logger"):
+            # vllm <= 0.5.1
+            tgis_stats_logger = TGISStatLogger(
+                vllm_stat_logger=self.engine.engine.stat_logger,
+                max_sequence_len=self.config.max_model_len,
+            )
+            self.engine.engine.stat_logger = tgis_stats_logger
+        elif hasattr(self.engine.engine, "stat_loggers"):
+            # vllm >= 0.5.2
+            tgis_stats_logger = TGISStatLogger(
+                vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"],
+                max_sequence_len=self.config.max_model_len,
+            )
+            self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger
+        else:
+            raise ValueError("engine doesn't have any known loggers.")

         self.health_servicer.set(
             self.SERVICE_NAME,
@@ -844,7 +850,10 @@ async def start_grpc_server(
     assert isinstance(engine, AsyncLLMEngine)
     assert isinstance(engine.engine, _AsyncLLMEngine)

-    logger.info(memory_summary(engine.engine.device_config.device))
+    if (device_type := engine.engine.device_config.device.type) == "cuda":
+        logger.info(memory_summary(engine.engine.device_config.device))
+    else:
+        logger.warning("Cannot print device usage for device type: %s", device_type)

     server = aio.server()

@@ -913,21 +922,15 @@ async def run_grpc_server(
     *,
     disable_log_stats: bool,
 ) -> None:
-    async def _force_log() -> None:
-        while True:
-            await asyncio.sleep(10)
-            await engine.do_log_stats()
-
-    if not disable_log_stats:
-        asyncio.create_task(_force_log())  # noqa: RUF006
-
     assert args is not None

     server = await start_grpc_server(engine, args)

     try:
         while True:
-            await asyncio.sleep(60)
+            await asyncio.sleep(10)
+            if not disable_log_stats:
+                await engine.do_log_stats()
     except asyncio.CancelledError:
         print("Gracefully stopping gRPC server")  # noqa: T201
         await server.stop(30)  # TODO configurable grace
12 changes: 9 additions & 3 deletions src/vllm_tgis_adapter/tgis_utils/metrics.py
@@ -5,7 +5,13 @@

 from prometheus_client import Counter, Gauge, Histogram
 from vllm import RequestOutput
-from vllm.engine.metrics import StatLogger, Stats
+from vllm.engine.metrics import Stats
+
+try:
+    from vllm.engine.metrics import StatLoggerBase
+except ImportError:
+    # vllm <= 0.5.1
+    from vllm.engine.metrics import StatLogger as StatLoggerBase

 from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
     BatchedTokenizeRequest,
@@ -102,10 +108,10 @@ def observe_generation_success(self, start_time: float) -> None:
         self.tgi_request_duration.observe(duration)


-class TGISStatLogger(StatLogger):
+class TGISStatLogger(StatLoggerBase):
     """Wraps the vLLM StatLogger to report TGIS metric names for compatibility."""

-    def __init__(self, vllm_stat_logger: StatLogger, max_sequence_len: int):
+    def __init__(self, vllm_stat_logger: StatLoggerBase, max_sequence_len: int):
         # Not calling super-init because we're wrapping and delegating to
         # vllm_stat_logger
         self._vllm_stat_logger = vllm_stat_logger
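
Because `TGISStatLogger` skips `super().__init__()`, it works purely by delegation: calls from vLLM are forwarded to the wrapped logger, with TGIS-named metrics recorded alongside. A condensed sketch of that shape; the `log`/`info` method names follow vLLM's stat-logger interface around these versions and should be read as an assumption:

```python
class DelegatingStatLogger:
    """Wrap another stat logger and piggyback extra metrics on its calls."""

    def __init__(self, wrapped, max_sequence_len: int):
        self._wrapped = wrapped
        self.max_sequence_len = max_sequence_len

    def log(self, stats) -> None:
        self._wrapped.log(stats)  # keep vLLM's own Prometheus metrics flowing
        # ...then update TGIS-compatible counters/histograms from `stats` here

    def info(self, type_: str, obj) -> None:
        self._wrapped.info(type_, obj)  # pass-through for newer vllm versions
```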