Skip to content

Commit

Permalink
feat: Add retry_timeout to EvalTask in vertexai.preview.evaluation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644169601
  • Loading branch information
jsondai authored and copybara-github committed Jun 17, 2024
1 parent 28a091a commit 4d9ee9d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 2 deletions.
2 changes: 2 additions & 0 deletions vertexai/preview/evaluation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ class EvaluationRunConfig:
CustomMetric instances, or PairwiseMetric instances to evaluate.
column_map: The dictionary of column name overrides in the dataset.
client: The asynchronous evaluation client.
retry_timeout: How long to keep retrying the evaluation requests, in seconds.
"""

dataset: "pd.DataFrame"
metrics: List[Union[str, metrics_base.CustomMetric, metrics_base.PairwiseMetric]]
column_map: Dict[str, str]
client: gapic_evaluation_services.EvaluationServiceAsyncClient
retry_timeout: float

def validate_dataset_column(self, column_name: str) -> None:
"""Validates that the column names in the column map are in the dataset.
Expand Down
4 changes: 4 additions & 0 deletions vertexai/preview/evaluation/_eval_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def evaluate(
prompt_template: Optional[str] = None,
experiment_run_name: Optional[str] = None,
response_column_name: Optional[str] = None,
retry_timeout: float = 600.0,
) -> EvalResult:
"""Runs an evaluation for the EvalTask.
Expand All @@ -318,6 +319,8 @@ def evaluate(
unique experiment run name is used.
response_column_name: The column name of model response in the dataset. If
provided, this will override the `response_column_name` of the `EvalTask`.
retry_timeout: How long to keep retrying the evaluation requests for
the whole evaluation dataset, in seconds.
Returns:
The evaluation result.
Expand Down Expand Up @@ -364,6 +367,7 @@ def evaluate(
content_column_name=self.content_column_name,
reference_column_name=self.reference_column_name,
response_column_name=response_column_name,
retry_timeout=retry_timeout,
)
return eval_result

Expand Down
6 changes: 5 additions & 1 deletion vertexai/preview/evaluation/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ async def _compute_metrics(
row_dict=row_dict,
evaluation_run_config=evaluation_run_config,
),
retry_timeout=evaluation_run_config.retry_timeout,
)
)
if isinstance(metric, metrics_base.PairwiseMetric):
Expand Down Expand Up @@ -618,6 +619,7 @@ def evaluate(
response_column_name: str = "response",
context_column_name: str = "context",
instruction_column_name: str = "instruction",
retry_timeout: float = 600.0,
) -> evaluation_base.EvalResult:
"""Runs the evaluation for metrics.
Expand All @@ -644,7 +646,8 @@ def evaluate(
not set, default to `context`.
instruction_column_name: The column name of the instruction prompt in the
dataset. If not set, default to `instruction`.
retry_timeout: How long to keep retrying the evaluation requests for the
whole evaluation dataset, in seconds.
Returns:
EvalResult with summary metrics and a metrics table for per-instance
metrics.
Expand All @@ -670,6 +673,7 @@ def evaluate(
constants.Dataset.INSTRUCTION_COLUMN: instruction_column_name,
},
client=utils.create_evaluation_service_async_client(),
retry_timeout=retry_timeout,
)

if set(evaluation_run_config.metrics).intersection(
Expand Down
4 changes: 3 additions & 1 deletion vertexai/preview/evaluation/metrics/_instance_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,12 +609,14 @@ def _handle_response(
async def evaluate_instances_async(
client: gapic_evaluation_services.EvaluationServiceAsyncClient,
request: gapic_eval_service_types.EvaluateInstancesRequest,
retry_timeout: float,
):
"""Evaluates an instance asynchronously.
Args:
client: The client to use for evaluation.
request: An EvaluateInstancesRequest.
retry_timeout: How long to keep retrying the evaluation requests, in seconds.
Returns:
The metric score of the evaluation.
Expand All @@ -626,7 +628,7 @@ async def evaluate_instances_async(
initial=0.250,
maximum=90.0,
multiplier=1.45,
deadline=600.0,
timeout=retry_timeout,
predicate=api_core.retry.if_exception_type(
api_core.exceptions.Aborted,
api_core.exceptions.DeadlineExceeded,
Expand Down

0 comments on commit 4d9ee9d

Please sign in to comment.