Source code for trinity.perf.report_metrics
from __future__ import annotations
from typing import Any, Optional
# Keys looked up in each per-step metric dict by
# compute_global_token_throughput_metrics.

# Inputs for the token totals: experience count times mean length, per step.
EXPERIENCE_COUNT_METRIC_KEY = "experience_pipeline/experience_count"
PROMPT_LENGTH_MEAN_METRIC_KEY = "rollout/prompt_length/mean"
RESPONSE_LENGTH_MEAN_METRIC_KEY = "rollout/response_length/mean"

# Per-step API-call rates, which are averaged across steps.
API_CALL_PROMPT_TOKENS_PER_SECOND_MEAN_METRIC_KEY = (
    "rollout/api_call_prompt_tokens_per_second/mean"
)
API_CALL_RESPONSE_TOKENS_PER_SECOND_MEAN_METRIC_KEY = (
    "rollout/api_call_response_tokens_per_second/mean"
)
[docs]
def _mean_of_present_values(
    step_metrics: list[dict[str, Any]], metric_key: str
) -> float | None:
    """Average ``metric_key`` over the steps where it is present and not ``None``.

    Returns ``None`` when no step provides a value, so callers can tell
    "no data" apart from a genuine mean of ``0.0``.
    """
    values = [
        float(step_metric[metric_key])
        for step_metric in step_metrics
        if step_metric.get(metric_key) is not None
    ]
    return sum(values) / len(values) if values else None


def compute_global_token_throughput_metrics(
    execution_time_sec: Optional[float], step_metrics: list[dict[str, Any]]
) -> dict[str, float | None]:
    """Aggregate per-step metrics into run-level token throughput figures.

    Args:
        execution_time_sec: Elapsed time used as the divisor for the global
            throughput values. ``None`` or a non-positive value means the
            global throughput cannot be computed.
        step_metrics: One metric dict per step. Each metric key is optional
            within a step.

    Returns:
        A dict with four keys:

        * ``"prompt_tokens_per_second"`` / ``"response_tokens_per_second"``:
          the sum over steps of ``experience_count * mean length`` divided by
          ``execution_time_sec``, or ``None`` when ``execution_time_sec`` is
          ``None`` or ``<= 0``. Steps without an experience count are
          skipped; a missing length mean contributes nothing for that step.
        * ``"api_call_prompt_tokens_per_second"`` /
          ``"api_call_response_tokens_per_second"``: the unweighted mean of
          the per-step values that are not ``None``, or ``None`` when no step
          reports one. These do not depend on ``execution_time_sec``.
    """
    # The API-call means are computed first and exactly once; both return
    # paths share them, so the two branches cannot drift apart.
    throughput: dict[str, float | None] = {
        "prompt_tokens_per_second": None,
        "response_tokens_per_second": None,
        "api_call_prompt_tokens_per_second": _mean_of_present_values(
            step_metrics, API_CALL_PROMPT_TOKENS_PER_SECOND_MEAN_METRIC_KEY
        ),
        "api_call_response_tokens_per_second": _mean_of_present_values(
            step_metrics, API_CALL_RESPONSE_TOKENS_PER_SECOND_MEAN_METRIC_KEY
        ),
    }

    # Without a usable elapsed time the global rates stay None.
    if execution_time_sec is None or execution_time_sec <= 0:
        return throughput

    prompt_token_total = 0.0
    response_token_total = 0.0
    for step_metric in step_metrics:
        experience_count = step_metric.get(EXPERIENCE_COUNT_METRIC_KEY)
        if experience_count is None:
            # No count means the mean lengths cannot be turned into totals.
            continue
        count = float(experience_count)
        prompt_length_mean = step_metric.get(PROMPT_LENGTH_MEAN_METRIC_KEY)
        response_length_mean = step_metric.get(RESPONSE_LENGTH_MEAN_METRIC_KEY)
        if prompt_length_mean is not None:
            prompt_token_total += count * float(prompt_length_mean)
        if response_length_mean is not None:
            response_token_total += count * float(response_length_mean)

    elapsed_sec = float(execution_time_sec)
    throughput["prompt_tokens_per_second"] = prompt_token_total / elapsed_sec
    throughput["response_tokens_per_second"] = response_token_total / elapsed_sec
    return throughput