Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/unit/vertexai/genai/replays/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,32 @@ def test_evaluation_agent_data(client):
assert case_result.response_candidate_results is not None


def test_metric_resource_name(client):
    """Tests evaluate() with a metric referencing a registered metric resource.

    Points the client at the v1beta1 staging endpoint, then evaluates a
    one-row bring-your-own-response dataset against a ``types.Metric`` that
    carries a ``metric_resource_name``.
    """
    client._api_client._http_options.api_version = "v1beta1"
    client._api_client._http_options.base_url = (
        "https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
    )
    metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128"
    byor_df = pd.DataFrame(
        {
            "prompt": ["Write a simple story about a dinosaur"],
            "response": ["Once upon a time, there was a T-Rex named Rexy."],
        }
    )
    metric = types.Metric(
        name="my_custom_metric", metric_resource_name=metric_resource_name
    )
    evaluation_result = client.evals.evaluate(
        dataset=byor_df,
        metrics=[metric],
    )
    assert isinstance(evaluation_result, types.EvaluationResult)
    assert evaluation_result.eval_case_results is not None
    assert len(evaluation_result.eval_case_results) > 0
    # Guard before indexing: a missing/empty summary_metrics list should fail
    # with a clear assertion, not an IndexError/TypeError.
    assert evaluation_result.summary_metrics
    assert evaluation_result.summary_metrics[0].metric_name == "my_custom_metric"


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
2 changes: 1 addition & 1 deletion vertexai/_genai/_evals_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ def _resolve_evaluation_run_metrics(
raise
elif isinstance(metric_instance, types.Metric):
config_dict = t.t_metrics([metric_instance])[0]
res_name = config_dict.pop("metric_resource_name", None)
res_name = getattr(metric_instance, "metric_resource_name", None)
resolved_metrics_list.append(
types.EvaluationRunMetric(
metric=metric_instance.name,
Expand Down
175 changes: 157 additions & 18 deletions vertexai/_genai/_evals_metric_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import statistics
import time
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from google.genai import errors as genai_errors
from google.genai import _common
Expand All @@ -39,6 +39,9 @@
_MAX_RETRIES = 3


T = TypeVar("T", types.Metric, types.MetricSource, types.LLMMetric)


def _has_tool_call(intermediate_events: Optional[list[types.evals.Event]]) -> bool:
"""Checks if any event in intermediate_events has a function call."""
if not intermediate_events:
Expand Down Expand Up @@ -149,12 +152,18 @@ def _default_aggregate_scores(
)


class MetricHandler(abc.ABC):
class MetricHandler(abc.ABC, Generic[T]):
"""Abstract base class for metric handlers."""

def __init__(self, module: "evals.Evals", metric: types.Metric):
def __init__(self, module: "evals.Evals", metric: T):
self.module = module
self.metric = metric
self.metric: T = metric

@property
@abc.abstractmethod
def metric_name(self) -> str:
    """Returns the name of the metric polymorphically.

    Subclasses resolve the name from their specific metric type so that
    aggregation and result reporting never depend on ``metric.name``
    being populated directly.
    """
    raise NotImplementedError()

@abc.abstractmethod
def get_metric_result(
Expand All @@ -171,7 +180,7 @@ def aggregate(
raise NotImplementedError()


class ComputationMetricHandler(MetricHandler):
class ComputationMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for computation metrics."""

SUPPORTED_COMPUTATION_METRICS = frozenset(
Expand All @@ -188,6 +197,10 @@ class ComputationMetricHandler(MetricHandler):
}
)

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)
if self.metric.name not in self.SUPPORTED_COMPUTATION_METRICS:
Expand Down Expand Up @@ -299,11 +312,15 @@ def aggregate(
return _default_aggregate_scores(self.metric.name, eval_case_metric_results)


class TranslationMetricHandler(MetricHandler):
class TranslationMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for translation metrics."""

SUPPORTED_TRANSLATION_METRICS = frozenset({"comet", "metricx"})

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)

Expand Down Expand Up @@ -469,9 +486,13 @@ def aggregate(
return _default_aggregate_scores(self.metric.name, eval_case_metric_results)


class LLMMetricHandler(MetricHandler):
class LLMMetricHandler(MetricHandler[types.LLMMetric]):
"""Metric handler for LLM metrics."""

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.LLMMetric):
super().__init__(module=module, metric=metric)

Expand Down Expand Up @@ -750,9 +771,13 @@ def aggregate(
return _default_aggregate_scores(self.metric.name, eval_case_metric_results)


class CustomMetricHandler(MetricHandler):
class CustomMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for custom metrics."""

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)

Expand Down Expand Up @@ -853,9 +878,13 @@ def aggregate(
return _default_aggregate_scores(self.metric.name, eval_case_metric_results)


class PredefinedMetricHandler(MetricHandler):
class PredefinedMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for predefined metrics."""

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)
if self.metric.name not in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
Expand Down Expand Up @@ -1106,9 +1135,13 @@ def aggregate(
)


class CustomCodeExecutionMetricHandler(MetricHandler):
class CustomCodeExecutionMetricHandler(MetricHandler[types.Metric]):
"""Metric handler for custom code execution metrics."""

@property
def metric_name(self) -> str:
    """Metric name used to key results; placeholder when unset."""
    configured = self.metric.name
    if configured:
        return configured
    return "unknown_metric"

def __init__(self, module: "evals.Evals", metric: types.Metric):
super().__init__(module=module, metric=metric)

Expand Down Expand Up @@ -1242,6 +1275,108 @@ def aggregate(
)


class RegisteredMetricHandler(MetricHandler[types.MetricSource]):
    """Metric handler for registered metrics.

    A registered metric is referenced by resource name and evaluated
    server-side; request payloads are built by delegating to
    ``PredefinedMetricHandler`` using the underlying metric definition.
    """

    def __init__(
        self,
        module: "evals.Evals",
        metric: Union[types.MetricSource, types.MetricSourceDict],
    ):
        # Normalize dict inputs so downstream code can rely on attribute access.
        if isinstance(metric, dict):
            metric = types.MetricSource(**metric)
        super().__init__(module=module, metric=metric)

    # TODO: b/489823454 - Unify _build_request_payload with PredefinedMetricHandler.
    def _build_request_payload(
        self, eval_case: types.EvalCase, response_index: int
    ) -> dict[str, Any]:
        """Builds request payload for registered metric.

        Raises:
            ValueError: If the metric source carries no underlying metric
                definition to build the payload from.
        """
        if not self.metric.metric:
            raise ValueError(
                "Registered metric must have an underlying metric definition."
            )
        return PredefinedMetricHandler(
            self.module, metric=self.metric.metric
        )._build_request_payload(eval_case, response_index)

    @property
    def metric_name(self) -> str:
        """Resolves a display name from the metric definition or resource name."""
        if isinstance(self.metric, types.MetricSource):
            if self.metric.metric and self.metric.metric.name:
                return self.metric.metric.name
            if self.metric.metric_resource_name:
                return self.metric.metric_resource_name
            return "unknown"
        else:
            # Dispatch may also hand this handler a types.Metric instance
            # (selected by its metric_resource_name attribute).
            metric_like = self.metric
            if metric_like.name:
                return metric_like.name
            if metric_like.metric_resource_name:
                return metric_like.metric_resource_name
            return "unknown"

    @override
    def get_metric_result(
        self, eval_case: types.EvalCase, response_index: int
    ) -> types.EvalCaseMetricResult:
        """Processes a single evaluation case for a registered metric.

        Retries HTTP 429 (resource exhausted) responses with exponential
        backoff up to ``_MAX_RETRIES`` attempts. Any other failure is
        captured in the returned result's ``error_message`` rather than
        raised, so one bad case does not abort the whole run.
        """
        metric_name = self.metric_name
        try:
            payload = self._build_request_payload(eval_case, response_index)
            # Initialize so the name is bound even if the loop never assigns.
            api_response = None
            for attempt in range(_MAX_RETRIES):
                try:
                    api_response = self.module._evaluate_instances(
                        metric_sources=[self.metric],
                        instance=payload.get("instance"),
                        autorater_config=payload.get("autorater_config"),
                    )
                    break
                except genai_errors.ClientError as e:
                    if e.code == 429:
                        if attempt == _MAX_RETRIES - 1:
                            return types.EvalCaseMetricResult(
                                metric_name=metric_name,
                                error_message=f"Judge model resource exhausted after {_MAX_RETRIES} retries: {e}",
                            )
                        time.sleep(2**attempt)  # Exponential backoff: 1s, 2s, ...
                    else:
                        # Bare raise preserves the original traceback.
                        raise

            if api_response and api_response.metric_results:
                result_data = api_response.metric_results[0]
                error_message = None
                # Pass a default: getattr without one raises AttributeError
                # when the error object has no "code" attribute.
                if result_data.error and getattr(result_data.error, "code", None):
                    error_message = f"Error in metric result: {result_data.error}"
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    score=result_data.score,
                    explanation=result_data.explanation,
                    rubric_verdicts=result_data.rubric_verdicts,
                    error_message=error_message,
                )
            else:
                return types.EvalCaseMetricResult(
                    metric_name=metric_name,
                    error_message="Metric results missing in API response.",
                )
        except Exception as e:
            # Convert any failure into a per-case error result.
            return types.EvalCaseMetricResult(
                metric_name=metric_name, error_message=str(e)
            )

    @override
    def aggregate(
        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
    ) -> types.AggregatedMetricResult:
        """Aggregates the metric results for a registered metric."""
        logger.debug("Aggregating results for registered metric: %s", self.metric_name)
        return _default_aggregate_scores(
            self.metric_name, eval_case_metric_results, calculate_pass_rate=True
        )


_METRIC_HANDLER_MAPPING = [
(
lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
Expand All @@ -1251,6 +1386,10 @@ def aggregate(
lambda m: m.custom_function and isinstance(m.custom_function, Callable),
CustomMetricHandler,
),
(
lambda m: getattr(m, "metric_resource_name", None) is not None,
RegisteredMetricHandler,
),
(
lambda m: m.name in ComputationMetricHandler.SUPPORTED_COMPUTATION_METRICS,
ComputationMetricHandler,
Expand Down Expand Up @@ -1337,14 +1476,14 @@ def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:


def _aggregate_metric_results(
metric_handlers: list[MetricHandler],
metric_handlers: list[MetricHandler[Any]],
eval_case_results: list[types.EvalCaseResult],
) -> list[types.AggregatedMetricResult]:
"""Aggregates results by calling the aggregate method of each handler."""
aggregated_metric_results = []
logger.info("Aggregating results per metric...")
for handler in metric_handlers:
metric_name = handler.metric.name
metric_name = handler.metric_name
results_for_this_metric: list[types.EvalCaseMetricResult] = []
for case_result in eval_case_results:
if case_result.response_candidate_results:
Expand Down Expand Up @@ -1473,12 +1612,12 @@ def compute_metrics_and_aggregate(
"response %d for metric %s.",
eval_case_index,
response_index,
metric_handler_instance.metric.name,
metric_handler_instance.metric_name,
)
all_futures.append(
(
future,
metric_handler_instance.metric.name,
metric_handler_instance.metric_name,
eval_case_index,
response_index,
)
Expand All @@ -1489,25 +1628,25 @@ def compute_metrics_and_aggregate(
"response %d for metric %s: %s",
eval_case_index,
response_index,
metric_handler_instance.metric.name,
metric_handler_instance.metric_name,
e,
exc_info=True,
)
submission_errors.append(
(
metric_handler_instance.metric.name,
metric_handler_instance.metric_name,
eval_case_index,
response_index,
f"Error: {e}",
)
)
error_result = types.EvalCaseMetricResult(
metric_name=metric_handler_instance.metric.name,
metric_name=metric_handler_instance.metric_name,
error_message=f"Submission Error: {e}",
)
results_by_case_response_metric[eval_case_index][
response_index
][metric_handler_instance.metric.name] = error_result
][metric_handler_instance.metric_name] = error_result
case_indices_with_errors.add(eval_case_index)
pbar.update(1)

Expand Down
Loading
Loading