Skip to content

Commit

Permalink
feat: Add progress bar to custom metrics.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663810771
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 16, 2024
1 parent 11a39e3 commit 3974aec
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions vertexai/preview/evaluation/_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,14 @@ def _replace_metric_bundle_with_metrics(
def _compute_custom_metrics(
row_dict: Dict[str, Any],
custom_metrics: List[metrics_base.CustomMetric],
pbar: tqdm,
) -> Dict[str, Any]:
"""Computes custom metrics for a row.
Args:
row_dict: A dictionary of an instance in the eval dataset.
custom_metrics: A list of CustomMetrics.
pbar: A tqdm progress bar.
Returns:
A dictionary of an instance containing custom metric results.
Expand All @@ -178,6 +180,7 @@ def _compute_custom_metrics(
"""
for custom_metric in custom_metrics:
metric_output = custom_metric.metric_function(row_dict)
pbar.update(1)
if custom_metric.name in metric_output:
row_dict[custom_metric.name] = metric_output[custom_metric.name]
else:
Expand Down Expand Up @@ -613,6 +616,9 @@ def _compute_metrics(
)
row_count = len(evaluation_run_config.dataset)
api_request_count = len(api_metrics) * row_count
custom_metric_request_count = len(custom_metrics) * row_count
total_request_count = api_request_count + custom_metric_request_count

_LOGGER.info(
f"Computing metrics with a total of {api_request_count} Vertex online"
" evaluation service requests."
Expand All @@ -622,10 +628,10 @@ def _compute_metrics(
futures_by_metric = collections.defaultdict(list)

rate_limiter = utils.RateLimiter(evaluation_run_config.evaluation_service_qps)
with tqdm(total=api_request_count) as pbar:
with tqdm(total=total_request_count) as pbar:
with futures.ThreadPoolExecutor(max_workers=constants.MAX_WORKERS) as executor:
for idx, row in evaluation_run_config.dataset.iterrows():
row_dict = _compute_custom_metrics(row.to_dict(), custom_metrics)
row_dict = _compute_custom_metrics(row.to_dict(), custom_metrics, pbar)

instance_list.append(row_dict)

Expand Down

0 comments on commit 3974aec

Please sign in to comment.