Adds serverless endpoints back #445

Merged: 5 commits, Dec 17, 2024
2 changes: 1 addition & 1 deletion docs/source/package_reference/models.mdx
@@ -21,7 +21,7 @@
## Endpoints-based Models
### InferenceEndpointModel
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

### TGI ModelClient
File renamed without changes.
19 changes: 12 additions & 7 deletions src/lighteval/main_endpoint.py
@@ -146,6 +146,13 @@ def inference_endpoint(
str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
free_endpoint: Annotated[
bool,
Option(
help="Use serverless free endpoints instead of spinning up your own inference endpoint.",
rich_help_panel=HELP_PANEL_NAME_4,
),
] = False,
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4)
@@ -200,9 +207,7 @@ def inference_endpoint(
"""

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModelConfig,
)
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -220,10 +225,10 @@
parallelism_manager = ParallelismManager.NONE # since we're using inference endpoints in remote

# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

model_config = InferenceEndpointModelConfig.from_path(model_config_path)
if free_endpoint:
model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
else:
model_config = InferenceEndpointModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
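For context on the free_endpoint branch added above: a minimal sketch of a config file that ServerlessEndpointModelConfig.from_path can read, inferred from the from_path implementation further down in this diff and from the serverless_model.yaml fixture referenced in the updated test. The file name written here is an assumption, not part of this PR.

import yaml

# Sketch only: from_path() reads a top-level "model" key and passes its
# "base_params" mapping to the dataclass constructor.
config = {
    "model": {
        "base_params": {
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
            # "add_special_tokens": True,  # optional; the dataclass defaults to True
        }
    }
}

with open("serverless_model.yaml", "w") as f:  # hypothetical local path
    yaml.safe_dump(config, f)

# model_config = ServerlessEndpointModelConfig.from_path("serverless_model.yaml")

With such a file, the new option (declared as free_endpoint, so presumably exposed by Typer as --free-endpoint) selects ServerlessEndpointModelConfig instead of spinning up a dedicated inference endpoint.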
29 changes: 21 additions & 8 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -75,10 +75,18 @@


@dataclass
class InferenceModelConfig:
model: str
class ServerlessEndpointModelConfig:
model_name: str
add_special_tokens: bool = True

@classmethod
def from_path(cls, path: str) -> "ServerlessEndpointModelConfig":
import yaml

with open(path, "r") as f:
config = yaml.safe_load(f)["model"]
return cls(**config["base_params"])


@dataclass
class InferenceEndpointModelConfig:
@@ -150,7 +158,7 @@ class InferenceEndpointModel(LightevalModel):
"""

def __init__( # noqa: C901
self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
) -> None:
self.reuse_existing = getattr(config, "reuse_existing", False)
self._max_length = None
@@ -280,10 +288,10 @@ def __init__( # noqa: C901
else: # Free inference client
self.endpoint = None
self.endpoint_name = None
self.name = config.model
self.name = config.model_name
self.revision = "default"
self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
self.client = InferenceClient(model=config.model, token=env_config.token)
self.async_client = AsyncInferenceClient(model=config.model_name, token=env_config.token)
self.client = InferenceClient(model=config.model_name, token=env_config.token)

self.use_async = True # set to False for debug - async use is faster

@@ -293,7 +301,7 @@ def __init__( # noqa: C901
self.model_info = ModelInfo(
model_name=self.name,
model_sha=self.revision,
model_dtype=config.model_dtype or "default",
model_dtype=getattr(config, "model_dtype", "default"),
model_size=-1,
)

@@ -545,7 +553,12 @@ def loglikelihood(
cont_toks = torch.tensor(cur_request.tokenized_continuation)
len_choice = len(cont_toks)

logits = [t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None]
if self.endpoint: # inference endpoint
logits = [
t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
] # to check
else: # serverless endpoint
logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]

greedy_tokens = torch.tensor(logits).argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
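A standalone sketch of why the diff above switches from config.model_dtype to getattr(config, "model_dtype", "default"): ServerlessEndpointModelConfig declares no model_dtype field, so plain attribute access would raise AttributeError, while getattr falls back cleanly. The dataclass below is a hypothetical stand-in, not the real config class.

from dataclasses import dataclass

# Hypothetical stand-in mirroring ServerlessEndpointModelConfig's fields.
@dataclass
class _ServerlessCfg:
    model_name: str
    add_special_tokens: bool = True

cfg = _ServerlessCfg(model_name="meta-llama/Llama-3.1-8B-Instruct")

# "config.model_dtype or 'default'" would raise AttributeError here;
# getattr() returns the fallback instead.
print(getattr(cfg, "model_dtype", "default"))  # prints "default"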
4 changes: 2 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -27,7 +27,7 @@
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModel,
InferenceEndpointModelConfig,
InferenceModelConfig,
ServerlessEndpointModelConfig,
)
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -80,7 +80,7 @@ def load_model( # noqa: C901
if isinstance(config, TGIModelConfig):
return load_model_with_tgi(config)

if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
return load_model_with_inference_endpoints(config, env_config=env_config)

if isinstance(config, BaseModelConfig):
2 changes: 1 addition & 1 deletion tests/models/endpoints/test_endpoint_model.py
@@ -53,7 +53,7 @@ class TestInferenceEndpointModelConfig:
},
),
(
"examples/model_configs/endpoint_model_lite.yaml",
"examples/model_configs/serverless_model.yaml",
{
"model_name": "meta-llama/Llama-3.1-8B-Instruct",
# Defaults: