# Source code for trinity.utils.metrics

"""Unified metrics aggregation utilities for Trinity-RFT.

Metric keys may carry an aggregation-type suffix in the form ``name:agg``.
Supported suffixes: ``:mean``, ``:sum``, ``:max``, ``:min``, ``:last``.
Keys without a suffix default to mean aggregation.
"""

from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple

import numpy as np


class AggType(str, Enum):
    """Aggregation types that can be attached to a metric key as ``name:agg``.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (for example ``AggType.SUM == "sum"``) and can be built
    directly from a key suffix with ``AggType("sum")``.
    """

    MEAN = "mean"
    SUM = "sum"
    MAX = "max"
    MIN = "min"
    LAST = "last"
def take_last(values: List[float]) -> float:
    """Return the final element of *values*.

    Args:
        values: Sequence of collected metric values, in collection order.

    Returns:
        The element at the last position.

    Raises:
        IndexError: If *values* is empty.
    """
    last_value = values[-1]
    return last_value
# Reducers that collapse a list of values into one number for each non-MEAN
# aggregation type. AggType.MEAN has no entry here: the aggregation functions
# in this module handle MEAN in a dedicated branch before consulting this table.
AGG_REDUCERS: Dict[AggType, Callable[[List[float]], float]] = {
    AggType.SUM: sum,
    AggType.MAX: max,
    AggType.MIN: min,
    AggType.LAST: take_last,
}

# NumPy statistics that may be reported for MEAN metrics, keyed by the stat
# name that is appended to the output metric key.
STAT_FNS: Dict[str, Callable[[np.ndarray], np.generic]] = {
    "mean": np.mean,
    "max": np.max,
    "min": np.min,
    "std": np.std,
}
def group_numeric_metrics(
    metric_dicts: List[Dict[str, float]]
) -> Dict[Tuple[str, AggType], List[float]]:
    """Bucket numeric metric values by ``(name, aggregation type)``.

    Each key is split with ``parse_metric_key``; values that share the same
    parsed name and aggregation type are collected, in encounter order, into
    one list.

    Args:
        metric_dicts: Flat metric dictionaries. Values that are not ``int`` or
            ``float`` instances are skipped (``bool`` is a subclass of ``int``
            and is therefore kept).

    Returns:
        Mapping from ``(name, agg)`` to the list of values converted to
        ``float``. The returned object is a ``defaultdict(list)``.
    """
    buckets: Dict[Tuple[str, AggType], List[float]] = defaultdict(list)
    for metrics in metric_dicts:
        for raw_key, raw_value in metrics.items():
            if isinstance(raw_value, (int, float)):
                # parse_metric_key already returns the (name, agg) tuple used as bucket key.
                buckets[parse_metric_key(raw_key)].append(float(raw_value))
    return buckets
def group_metrics_by_canonical_key(
    metric_dicts: List[Dict[str, float]],
) -> Dict[str, Tuple[AggType, List[float]]]:
    """Bucket numeric metric values under a canonical string key.

    The canonical key is the bare ``name`` for MEAN metrics and
    ``name:agg`` for every other aggregation type, so an explicit ``:mean``
    suffix and a missing suffix end up in the same bucket.

    Args:
        metric_dicts: Flat metric dictionaries. Values that are not ``int`` or
            ``float`` instances are skipped.

    Returns:
        Mapping from canonical key to ``(agg, values)``, where ``values`` holds
        the collected numbers converted to ``float`` in encounter order.
    """
    buckets: Dict[str, Tuple[AggType, List[float]]] = {}
    for metrics in metric_dicts:
        for raw_key, raw_value in metrics.items():
            if not isinstance(raw_value, (int, float)):
                continue
            name, agg = parse_metric_key(raw_key)
            if agg == AggType.MEAN:
                canonical = name
            else:
                canonical = f"{name}:{agg.value}"
            # The first occurrence of a canonical key fixes its aggregation type.
            _, collected = buckets.setdefault(canonical, (agg, []))
            collected.append(float(raw_value))
    return buckets
def parse_metric_key(key: str) -> Tuple[str, AggType]:
    """Split a metric key into ``(name, aggregation_type)``.

    Only the text after the last ``:`` is considered as a suffix. When there is
    no ``:`` or the suffix is not a valid ``AggType`` value, the whole key is
    returned unchanged together with ``AggType.MEAN``.

    Examples:
        "reward" -> ("reward", AggType.MEAN)
        "experience_count:sum" -> ("experience_count", AggType.SUM)
        "model_version:last" -> ("model_version", AggType.LAST)
        "some:unknown_suffix" -> ("some:unknown_suffix", AggType.MEAN)
    """
    name, separator, suffix = key.rpartition(":")
    if not separator:
        # No colon at all: the key carries no aggregation annotation.
        return key, AggType.MEAN
    try:
        agg = AggType(suffix)
    except ValueError:
        # Unknown suffix: treat the colon as part of the metric name.
        return key, AggType.MEAN
    return name, agg
def aggregate_metrics(
    metric_dicts: List[Dict[str, float]],
    prefix: str = "",
    default_output_stats: List[str] | None = None,
) -> Dict[str, float]:
    """Aggregate metric dictionaries according to each key's aggregation type.

    Output keys:
        * MEAN metrics: one ``{prefix}/{name}/{stat}`` entry per requested stat.
        * SUM / MAX / MIN / LAST metrics: a single ``{prefix}/{name}/{agg}`` entry.

    Args:
        metric_dicts: List of flat metric dictionaries; non-numeric values are
            ignored by the grouping step.
        prefix: Optional prefix. When non-empty it is prepended as
            ``{prefix}/``; when empty no leading slash is added.
        default_output_stats: Stat names to report for MEAN metrics. Defaults
            to ``["mean", "max", "min"]``. Names that are not keys of
            ``STAT_FNS`` are skipped silently.

    Returns:
        Flat dictionary of aggregated metrics; empty when *metric_dicts* is empty.
    """
    if not metric_dicts:
        return {}

    stats = ["mean", "max", "min"] if default_output_stats is None else default_output_stats
    key_prefix = f"{prefix}/" if prefix else ""

    aggregated: Dict[str, float] = {}
    for (name, agg), values in group_numeric_metrics(metric_dicts).items():
        base = f"{key_prefix}{name}"

        if agg != AggType.MEAN:
            reducer = AGG_REDUCERS.get(agg)
            if reducer is not None:
                aggregated[f"{base}/{agg.value}"] = reducer(values)
            continue

        # MEAN metrics fan out into one entry per requested statistic.
        arr = np.array(values)
        for stat in stats:
            stat_fn = STAT_FNS.get(stat)
            if stat_fn is not None:
                aggregated[f"{base}/{stat}"] = stat_fn(arr).item()
    return aggregated
def aggregate_eval_metrics(
    metric_dicts: List[Dict[str, float]],
    prefix: str = "",
    output_stats: List[str] | None = None,
    detailed_stats: bool = False,
) -> Dict[str, float]:
    """Aggregate eval metrics, optionally with detailed statistics.

    Output keys:
        * MEAN metrics with ``detailed_stats=True``: one
          ``{prefix}/{name}/{stat}`` entry per name in *output_stats*.
        * MEAN metrics with ``detailed_stats=False``: only the mean, stored
          under ``{prefix}/{name}`` (no stat suffix).
        * SUM / MAX / MIN / LAST metrics: a single ``{prefix}/{name}/{agg}`` entry.

    Args:
        metric_dicts: List of flat metric dictionaries; non-numeric values are
            ignored by the grouping step.
        prefix: Optional prefix. When non-empty it is prepended as ``{prefix}/``.
        output_stats: Stat names used when *detailed_stats* is true. Defaults
            to ``["mean", "max", "min", "std"]``. Names that are not keys of
            ``STAT_FNS`` are skipped silently.
        detailed_stats: Whether MEAN metrics report every stat or just the mean.

    Returns:
        Flat dictionary of aggregated metrics; empty when *metric_dicts* is empty.
    """
    if not metric_dicts:
        return {}

    stats = output_stats if output_stats is not None else ["mean", "max", "min", "std"]
    key_prefix = f"{prefix}/" if prefix else ""

    aggregated: Dict[str, float] = {}
    for (name, agg), values in group_numeric_metrics(metric_dicts).items():
        base = f"{key_prefix}{name}"

        if agg != AggType.MEAN:
            reducer = AGG_REDUCERS.get(agg)
            if reducer is not None:
                aggregated[f"{base}/{agg.value}"] = reducer(values)
            continue

        arr = np.array(values)
        if not detailed_stats:
            # Compact mode: a single value keyed by the bare metric name.
            aggregated[base] = np.mean(arr).item()
            continue

        for stat in stats:
            stat_fn = STAT_FNS.get(stat)
            if stat_fn is not None:
                aggregated[f"{base}/{stat}"] = stat_fn(arr).item()
    return aggregated
def aggregate_run_level_metrics(metric_dicts: List[Dict[str, float]]) -> Dict[str, float]:
    """Collapse experience-level metrics into one run-level metric dict.

    Output keys are the canonical keys produced by
    ``group_metrics_by_canonical_key`` (bare ``name`` for MEAN metrics,
    ``name:agg`` otherwise), so the aggregation annotation stays visible to
    later aggregation stages.

    Aggregation rules:
        - MEAN keys: arithmetic mean of the collected values
        - SUM keys: sum of the collected values
        - MAX keys: maximum of the collected values
        - MIN keys: minimum of the collected values
        - LAST keys: last collected value

    Args:
        metric_dicts: List of flat metric dictionaries.

    Returns:
        One aggregated value per canonical key; empty when *metric_dicts* is empty.
    """
    if not metric_dicts:
        return {}

    run_metrics: Dict[str, float] = {}
    for canonical_key, (agg, values) in group_metrics_by_canonical_key(metric_dicts).items():
        if agg == AggType.MEAN:
            run_metrics[canonical_key] = sum(values) / len(values)
        else:
            reducer = AGG_REDUCERS.get(agg)
            if reducer is not None:
                run_metrics[canonical_key] = reducer(values)
    return run_metrics
# adapted from verl/trainer/ppo/metric_utils.py
def bootstrap_metric(
    data: List[Any],
    subset_size: int,
    reduce_fns: List[Callable[[List[Any]], float]],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> List[Tuple[float, float]]:
    """Estimate metric statistics with bootstrap resampling.

    For each of ``n_bootstrap`` rounds, ``subset_size`` items are drawn from
    *data* with replacement and every function in *reduce_fns* is applied to
    the drawn sample.

    Args:
        data: Population to resample from.
        subset_size: Number of items drawn (with replacement) per round.
        reduce_fns: Functions mapping a resampled list to a number.
        n_bootstrap: Number of resampling rounds.
        seed: Seed for the random generator used for resampling.

    Returns:
        One ``(mean, std)`` tuple per entry of *reduce_fns*, computed over that
        function's values across all rounds.

    Raises:
        ValueError: Propagated from NumPy when *data* is empty and a non-empty
            sample is requested.
    """
    # Use a private generator instead of np.random.seed(): reseeding the global
    # NumPy RNG here would silently reset the random state of every other piece
    # of code in the process. RandomState(seed) yields the same draw sequence as
    # the seeded global legacy generator, so results are unchanged.
    rng = np.random.RandomState(seed)
    population_size = len(data)

    bootstrap_metric_lists: List[List[float]] = [[] for _ in reduce_fns]
    for _ in range(n_bootstrap):
        bootstrap_idxs = rng.choice(population_size, size=subset_size, replace=True)
        bootstrap_data = [data[i] for i in bootstrap_idxs]
        for per_fn_values, reduce_fn in zip(bootstrap_metric_lists, reduce_fns):
            per_fn_values.append(reduce_fn(bootstrap_data))
    return [(float(np.mean(lst)), float(np.std(lst))) for lst in bootstrap_metric_lists]
def calculate_task_level_metrics(
    metrics: List[Dict[str, float]], is_eval: bool
) -> Dict[str, float]:
    """Calculate task-level metrics from multiple runs of the same task.

    Behavior:
        * ``is_eval=False``: delegates to ``aggregate_run_level_metrics``.
        * ``is_eval=True``:
            - non-MEAN metrics are reduced with their ``AGG_REDUCERS`` entry and
              stored under ``{name}:{agg}``;
            - MEAN metrics whose name contains ``time/task_execution`` or
              ``time/run_execution`` are plainly averaged and stored under the
              canonical key;
            - every other MEAN metric yields ``{key}/mean@{n}`` and
              ``{key}/std@{n}`` (``n`` = number of values) and, when ``n > 1``,
              bootstrap estimates ``{key}/best@{k}`` / ``{key}/worst@{k}`` for
              ``k`` in the powers of two below ``n`` followed by ``n`` itself.

    Args:
        metrics: One flat metric dictionary per run of the task.
        is_eval: Whether to produce eval-style statistics.

    Returns:
        Flat dictionary of task-level metrics; empty when *metrics* is empty.
    """
    if not metrics:
        return {}
    if not is_eval:
        # aggregate_run_level_metrics groups the metrics itself, so return
        # before grouping here to avoid doing that work twice.
        return aggregate_run_level_metrics(metrics)

    grouped = group_metrics_by_canonical_key(metrics)
    result: Dict[str, float] = {}
    for key, (agg, values) in grouped.items():
        name, _ = parse_metric_key(key)

        if agg != AggType.MEAN:
            reducer = AGG_REDUCERS.get(agg)
            if reducer is not None:
                result[f"{name}:{agg.value}"] = reducer(values)
            continue

        if "time/task_execution" in name or "time/run_execution" in name:
            # Timing metrics only get a plain average, no @n statistics.
            result[key] = sum(values) / len(values)
            continue

        n_values = len(values)
        result[f"{key}/mean@{n_values}"] = float(np.mean(values))
        result[f"{key}/std@{n_values}"] = float(np.std(values))
        if n_values <= 1:
            continue

        # Subset sizes: 2, 4, 8, ... strictly below n_values, then n_values itself.
        ns = []
        n = 2
        while n < n_values:
            ns.append(n)
            n *= 2
        ns.append(n_values)

        for subset_size in ns:
            # Expected best (max) and worst (min) value among subset_size resampled runs.
            [(best_mean, _), (worst_mean, _)] = bootstrap_metric(
                data=values,
                subset_size=subset_size,
                reduce_fns=[max, min],
                seed=42,
            )
            result[f"{key}/best@{subset_size}"] = best_mean
            result[f"{key}/worst@{subset_size}"] = worst_mean
    return result