import argparse
import json
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
import matplotlib.pyplot as plt
import streamlit as st
try:
from streamlit.runtime.scriptrunner import get_script_run_ctx
except ImportError: # pragma: no cover - fallback for streamlit runtime layout changes
from streamlit.runtime.scriptrunner_utils.script_run_context import (
get_script_run_ctx,
)
# Step-metric key prefixes to chart, keyed by the report's ``run_meta.module`` value.
# A step-metric key is charted when it starts with any listed prefix; a module with an
# empty list (or a module that is not listed at all) gets no step-metric charts.
STEP_METRIC_PREFIXES_BY_MODULE: dict[str, list[str]] = {
    "explorer": [
        "rollout/time/run_execution/mean",
        "rollout/time/task_execution/mean",
        "rollout/prompt_length/mean",
        "rollout/response_length/mean",
        "rollout/api_call_prompt_tokens_per_second/mean",
        "rollout/api_call_response_tokens_per_second/mean",
        "experience_pipeline/experience_count",
    ],
    "trainer": [],
}
# Key of the memory sample in each ``resource_timeline`` entry; plotted in the
# "Memory Usage" chart with an "MB" axis label (presumably resident set size — confirm
# against the report writer).
MEMORY_SERIES_KEY = "memory_rss_mb"
[docs]
class PerfReportViewer:
    """Launcher that re-runs this module as a Streamlit app."""

    @staticmethod
    def run_viewer(report_path: str, port: int) -> None:
        """Start the Streamlit perf report viewer.

        Replaces ``sys.argv`` with a ``streamlit run <this file>`` command line that
        serves on ``port`` with the file watcher disabled, forwards
        ``--report <report_path>`` to the script (everything after ``--``), and then
        exits the process with the value returned by Streamlit's ``cli.main()``.
        """
        from streamlit.web import cli

        script_path = str(Path(__file__).resolve())
        server_options = [
            "--server.port",
            str(port),
            "--server.fileWatcherType",
            "none",
        ]
        script_options = ["--report", report_path]
        sys.argv = ["streamlit", "run", script_path, *server_options, "--", *script_options]
        sys.exit(cli.main())
[docs]
def launch_report_viewer(report_path: str, port: int) -> None:
    """Launch the Streamlit perf report viewer from another CLI entrypoint.

    Delegates to ``PerfReportViewer.run_viewer``, which rewrites ``sys.argv`` and
    terminates the process through ``sys.exit``.
    """
    viewer = PerfReportViewer
    viewer.run_viewer(report_path=report_path, port=port)
[docs]
def has_streamlit_context() -> bool:
    """Return True when a Streamlit script-run context is active.

    ``get_script_run_ctx`` returns ``None`` otherwise; the ``__main__`` guard uses this
    to choose between rendering the page and relaunching through Streamlit.
    """
    run_ctx = get_script_run_ctx()
    return run_ctx is not None
[docs]
def configure_streamlit_page() -> None:
    """Set the page title and wide layout, only when running inside Streamlit."""
    if not has_streamlit_context():
        return
    st.set_page_config(page_title="Trinity Performance Report", layout="wide")
[docs]
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the viewer's command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Namespace with ``report`` (path to the perf report JSON, required) and
        ``port`` (int, default 8503).
    """
    parser = argparse.ArgumentParser(description="Trinity Performance Report Viewer")
    parser.add_argument("--report", type=str, required=True, help="Path to the perf report JSON.")
    parser.add_argument(
        "--port",
        type=int,
        default=8503,
        help="Port used when auto-launching the Streamlit report viewer.",
    )
    return parser.parse_args(argv)
[docs]
def load_report(report_path: str) -> dict[str, Any]:
    """Load a perf report JSON file.

    Args:
        report_path: Path to the report JSON file (read as UTF-8).

    Returns:
        The decoded top-level JSON object.

    Raises:
        FileNotFoundError: If ``report_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is not an object; the renderers
            call ``.get`` on it and would otherwise fail later with an AttributeError.
        OSError: For other I/O failures (e.g. the path is a directory).
    """
    report_file = Path(report_path)
    # EAFP: open directly instead of checking exists() first, which is race-prone.
    try:
        with report_file.open("r", encoding="utf-8") as handle:
            report = json.load(handle)
    except FileNotFoundError as error:
        raise FileNotFoundError(f"Report file not found: {report_path}") from error
    if not isinstance(report, dict):
        raise ValueError(
            "Report file must contain a JSON object at the top level, "
            f"got {type(report).__name__}: {report_path}"
        )
    return report
[docs]
def infer_module_name(report: dict[str, Any]) -> str:
    """Return ``report["run_meta"]["module"]`` converted to ``str``.

    When ``run_meta`` or its ``module`` entry is missing, the result is ``"None"``
    (``str`` of a missing value), which matches no module-specific configuration.
    """
    run_meta = report.get("run_meta")
    if not isinstance(run_meta, dict):
        # run_meta may be absent or serialized as JSON null; treat both as empty.
        run_meta = {}
    return str(run_meta.get("module"))
[docs]
def get_step_metric_prefixes(report: dict[str, Any]) -> list[str]:
    """Return the step-metric key prefixes configured for the report's module.

    Modules without an entry in ``STEP_METRIC_PREFIXES_BY_MODULE`` get an empty list.
    """
    prefixes_by_module = STEP_METRIC_PREFIXES_BY_MODULE
    module = infer_module_name(report)
    if module in prefixes_by_module:
        return prefixes_by_module[module]
    return []
[docs]
def metric_label(metric_name: str) -> str:
    """Turn a snake_case metric key into a title-cased display label.

    Example: ``"startup_time_sec"`` -> ``"Startup Time Sec"``.
    """
    words = metric_name.split("_")
    return " ".join(words).title()
[docs]
def gpu_series_label(gpu_payload: dict[str, Any]) -> str:
    """Build the legend label for one GPU series, e.g. ``"GPU 0 (A100)"``.

    A missing ``gpu_id`` key is shown as ``?``; the name suffix is omitted when the
    name is missing or falsy.
    """
    base_label = f"GPU {gpu_payload.get('gpu_id', '?')}"
    gpu_name = gpu_payload.get("name")
    return f"{base_label} ({gpu_name})" if gpu_name else base_label
[docs]
def render_metric_card(metric_name: str, value: Any) -> None:
    """Render one global metric as a styled HTML card via ``st.markdown``.

    Args:
        metric_name: Metric key; the card title comes from ``metric_label``.
        value: Raw metric value; the card body comes from ``format_metric_value``.
    """
    display_value = format_metric_value(value)
    label = metric_label(metric_name)
    # NOTE(review): label and display_value are interpolated unescaped and rendered with
    # unsafe_allow_html=True; consider html.escape if report values can contain markup.
    st.markdown(
        f"""
<div style="
background: linear-gradient(135deg, #eef6ff 0%, #f7fbff 100%);
border: 1px solid #d7e6fb;
border-radius: 14px;
padding: 16px 18px;
min-height: 108px;
box-shadow: 0 1px 3px rgba(15, 23, 42, 0.08);
">
<div style="font-size: 0.9rem; color: #4a5a70; margin-bottom: 10px;">{label}</div>
<div style="font-size: 1.6rem; font-weight: 700; color: #0f172a;">{display_value}</div>
</div>
""",
        unsafe_allow_html=True,
    )
[docs]
def build_elapsed_series(series: list[dict[str, Any]]) -> tuple[list[float], list[float]]:
    """Convert ``{"timestamp", "value"}`` points into (elapsed seconds, values) lists.

    Elapsed time is measured from the first point's timestamp; timestamps and values
    are coerced with ``float``. An empty input yields two empty lists.
    """
    elapsed: list[float] = []
    values: list[float] = []
    if series:
        origin = float(series[0]["timestamp"])
        elapsed = [float(point["timestamp"]) - origin for point in series]
        values = [float(point["value"]) for point in series]
    return elapsed, values
[docs]
def build_scalar_timeline_series(
    timeline: list[dict[str, Any]], metric_key: str
) -> list[dict[str, float]]:
    """Extract ``{"timestamp", "value"}`` points for one scalar metric.

    Timeline samples where ``metric_key`` is missing or ``None`` are skipped.
    """
    points: list[dict[str, float]] = []
    for sample in timeline:
        metric_value = sample.get(metric_key)
        if metric_value is None:
            continue
        points.append({"timestamp": sample["timestamp"], "value": metric_value})
    return points
[docs]
def build_gpu_timeline_series(
    timeline: list[dict[str, Any]], metric_key: str
) -> dict[str, dict[str, Any]]:
    """Group per-GPU samples of ``metric_key`` by GPU id.

    Args:
        timeline: Resource samples, each with a ``timestamp`` and an optional
            ``gpu_metrics`` list of per-GPU dicts.
        metric_key: Per-GPU field to collect.

    Returns:
        Dict keyed by ``str(gpu_id)`` in first-seen order. Each payload holds the
        GPU's ``gpu_id``, ``name`` (from its first sample with data) and ``values``,
        a list of ``{"timestamp", "value"}`` points in timeline order. GPU samples
        whose ``metric_key`` is missing or ``None`` are skipped.
    """
    series_by_gpu: dict[str, dict[str, Any]] = {}
    for sample in timeline:
        timestamp = sample.get("timestamp")
        # gpu_metrics may be absent or serialized as JSON null; treat both as "no GPUs".
        for gpu_sample in sample.get("gpu_metrics") or []:
            if gpu_sample.get(metric_key) is None:
                continue
            gpu_key = str(gpu_sample.get("gpu_id"))
            gpu_payload = series_by_gpu.setdefault(
                gpu_key,
                {
                    "gpu_id": gpu_sample.get("gpu_id"),
                    "name": gpu_sample.get("name"),
                    "values": [],
                },
            )
            gpu_payload["values"].append({"timestamp": timestamp, "value": gpu_sample[metric_key]})
    return series_by_gpu
[docs]
def render_line_chart(
    title: str,
    x_values: list[float],
    y_series: dict[str, list[float]],
    y_label: str,
    legend_below: bool = False,
    legend_columns: int = 1,
) -> None:
    """Render a titled matplotlib line chart into the current Streamlit container.

    Args:
        title: Heading shown above the chart; also used in the empty-data message.
        x_values: Shared x axis (labelled as elapsed seconds).
        y_series: Series name -> y values; each series is plotted against ``x_values``.
        y_label: Y-axis label.
        legend_below: Place the legend below the axes instead of inside them.
        legend_columns: Maximum legend columns when ``legend_below`` is set.

    A legend is drawn only when more than one series is given.
    """
    st.markdown(f"#### {title}")
    if not x_values or not y_series:
        st.info(f"No data for {title}.")
        return
    figure, axis = plt.subplots(figsize=(6, 2.6))
    for series_name, y_values in y_series.items():
        # Trim both sides to the common length: matplotlib raises when x and y differ.
        point_count = min(len(x_values), len(y_values))
        axis.plot(x_values[:point_count], y_values[:point_count], label=series_name)
    axis.set_xlabel("Elapsed Time (s)")
    axis.set_ylabel(y_label)
    axis.grid(True, alpha=0.3)
    if len(y_series) > 1:
        if legend_below:
            axis.legend(
                loc="upper center",
                bbox_to_anchor=(0.5, -0.28),
                ncol=min(legend_columns, len(y_series)),
                frameon=False,
                fontsize=8,
            )
            figure.subplots_adjust(bottom=0.32)
        else:
            axis.legend()
    st.pyplot(figure, clear_figure=True)
    # clear_figure only clears the figure's contents; close it so pyplot does not keep
    # every figure registered across Streamlit reruns.
    plt.close(figure)
[docs]
def render_step_metric_chart(step_metrics: list[dict[str, Any]], metric_key: str) -> None:
    """Plot one step metric against its ``step`` number.

    Args:
        step_metrics: Per-step rows; each plotted row must carry a ``step`` value.
        metric_key: Row key to plot; rows where it is missing or ``None`` are skipped.
    """
    # Build (step, value) pairs with a single filter so every value stays on its own
    # step; filtering steps and values separately misaligned them when a value was None.
    points = [
        (int(row["step"]), float(row[metric_key]))
        for row in step_metrics
        if row.get(metric_key) is not None
    ]
    st.markdown(f"#### {metric_label(metric_key)}")
    if not points:
        st.info(f"No data for {metric_key}.")
        return
    steps = [step for step, _ in points]
    values = [value for _, value in points]
    figure, axis = plt.subplots(figsize=(6, 2.6))
    axis.plot(steps, values, marker="o")
    axis.set_xlabel("Step")
    axis.set_ylabel(metric_label(metric_key))
    axis.grid(True, alpha=0.3)
    st.pyplot(figure, clear_figure=True)
    # Close the figure so pyplot does not keep it alive across Streamlit reruns.
    plt.close(figure)
[docs]
def compute_global_token_throughput_metrics(report: dict[str, Any]) -> dict[str, float | None]:
timing = report.get("timing", {})
return {
"prompt_tokens_per_second": timing.get("prompt_tokens_per_second"),
"response_tokens_per_second": timing.get("response_tokens_per_second"),
"api_call_prompt_tokens_per_second": timing.get("api_call_prompt_tokens_per_second"),
"api_call_response_tokens_per_second": timing.get("api_call_response_tokens_per_second"),
}
[docs]
def render_global_metrics(report: dict[str, Any]) -> None:
    """Render the "Global Metrics" section as metric cards, two per row.

    Shows ``startup_time_sec`` and ``execution_time_sec`` followed by the token
    throughput metrics, skipping every metric whose value is ``None``.
    """
    st.header("Global Metrics")
    timing = report.get("timing", {})
    candidates: list[tuple[str, Any]] = [
        (key, timing.get(key)) for key in ("startup_time_sec", "execution_time_sec")
    ]
    candidates += list(compute_global_token_throughput_metrics(report).items())
    visible = [item for item in candidates if item[1] is not None]
    if not visible:
        st.info("No global metrics found in this report.")
        return
    cards_per_row = 2
    for start in range(0, len(visible), cards_per_row):
        chunk = visible[start : start + cards_per_row]
        for column, (metric_key, value) in zip(st.columns(len(chunk)), chunk):
            with column:
                render_metric_card(metric_key, value)
[docs]
def render_step_metrics(report: dict[str, Any]) -> None:
    """Render the "Step Metrics" section: per-metric charts plus a table.

    A key is charted when it starts with one of the module's configured prefixes and
    has a non-``None`` value in at least one step row; charts appear in first-seen
    order, two per row. The expander table lists every step row minus ``raw_metrics``.
    """
    st.header("Step Metrics")
    step_metrics = report.get("step_metrics", [])
    if not step_metrics:
        st.info("No step metrics found in this report.")
        return
    prefixes = tuple(get_step_metric_prefixes(report))
    # dict keys serve as an insertion-ordered set of matching metric names.
    chartable: dict[str, None] = {}
    for row in step_metrics:
        for key, value in row.items():
            if key in ("step", "raw_metrics") or value is None:
                continue
            if key.startswith(prefixes):
                chartable.setdefault(key, None)
    metric_keys = list(chartable)
    if not metric_keys:
        st.info("No configured step metrics matched the current report.")
        return
    for pair_start in range(0, len(metric_keys), 2):
        columns = st.columns(2)
        for column, metric_key in zip(columns, metric_keys[pair_start : pair_start + 2]):
            with column:
                render_step_metric_chart(step_metrics, metric_key)
    with st.expander("Step Metrics Table"):
        table_rows = [
            {key: value for key, value in row.items() if key != "raw_metrics"}
            for row in step_metrics
        ]
        st.dataframe(table_rows, use_container_width=True)
[docs]
def _build_aligned_gpu_series(
    timeline: list[dict[str, Any]], metric_key: str
) -> tuple[list[float], dict[str, dict[str, Any]]]:
    """Place every GPU's ``metric_key`` samples on one shared elapsed-time axis.

    Returns ``(x_values, series_by_gpu)``.

    ``x_values`` holds the seconds elapsed since the first timestamped sample that
    has a value for any GPU.

    ``series_by_gpu`` is keyed by ``str(gpu_id)`` in first-seen order. Each payload
    carries ``gpu_id``, ``name`` (from the GPU's first sample with data) and
    ``values``, which has exactly one entry per x point. Where a GPU had no value,
    the entry is NaN, which matplotlib draws as a gap.

    Samples without a timestamp are skipped, because they cannot be placed on the axis.
    """
    timestamps: list[float] = []
    per_sample_values: list[dict[str, float]] = []
    series_by_gpu: dict[str, dict[str, Any]] = {}
    for sample in timeline:
        timestamp = sample.get("timestamp")
        if timestamp is None:
            continue
        sample_values: dict[str, float] = {}
        # gpu_metrics may be absent or serialized as JSON null.
        for gpu_sample in sample.get("gpu_metrics") or []:
            value = gpu_sample.get(metric_key)
            if value is None:
                continue
            gpu_key = str(gpu_sample.get("gpu_id"))
            series_by_gpu.setdefault(
                gpu_key,
                {"gpu_id": gpu_sample.get("gpu_id"), "name": gpu_sample.get("name"), "values": []},
            )
            sample_values[gpu_key] = float(value)
        if sample_values:
            timestamps.append(float(timestamp))
            per_sample_values.append(sample_values)
    if not timestamps:
        return [], {}
    start_timestamp = timestamps[0]
    x_values = [timestamp - start_timestamp for timestamp in timestamps]
    missing = float("nan")
    for gpu_key, payload in series_by_gpu.items():
        payload["values"] = [values.get(gpu_key, missing) for values in per_sample_values]
    return x_values, series_by_gpu


def render_resource_utilization(report: dict[str, Any]) -> None:
    """Render the "Resource Utilization" section.

    The section has four charts: CPU utilization, memory usage, GPU utilization and
    GPU memory usage. The CPU (``cpu_percent``) and memory (``MEMORY_SERIES_KEY``)
    series are each plotted against seconds since their own first sample. In each
    GPU chart, all GPUs share one time axis.
    """
    st.header("Resource Utilization")
    resource_timeline = report.get("resource_timeline") or []
    cpu_series = build_scalar_timeline_series(resource_timeline, "cpu_percent")
    cpu_x, cpu_y = build_elapsed_series(cpu_series)
    memory_series = build_scalar_timeline_series(resource_timeline, MEMORY_SERIES_KEY)
    memory_x, memory_y = build_elapsed_series(memory_series)
    # Align all GPUs on one time axis so no GPU is plotted against another GPU's
    # timestamps; matplotlib also rejects series whose length differs from the x axis.
    gpu_util_x, gpu_util_payloads = _build_aligned_gpu_series(
        resource_timeline, "gpu_util_percent"
    )
    gpu_util_y = {
        gpu_series_label(payload): payload["values"] for payload in gpu_util_payloads.values()
    }
    gpu_memory_x, gpu_memory_payloads = _build_aligned_gpu_series(
        resource_timeline, "gpu_memory_used_mb"
    )
    gpu_memory_y = {
        gpu_series_label(payload): payload["values"] for payload in gpu_memory_payloads.values()
    }
    first_row = st.columns(2)
    with first_row[0]:
        render_line_chart("CPU Utilization", cpu_x, {"CPU": cpu_y}, "CPU %")
    with first_row[1]:
        render_line_chart("Memory Usage", memory_x, {"Memory": memory_y}, "MB")
    second_row = st.columns(2)
    with second_row[0]:
        render_line_chart(
            "GPU Utilization",
            gpu_util_x,
            gpu_util_y,
            "GPU %",
            legend_below=True,
            legend_columns=2,
        )
    with second_row[1]:
        render_line_chart(
            "GPU Memory Usage",
            gpu_memory_x,
            gpu_memory_y,
            "MB",
            legend_below=True,
            legend_columns=2,
        )
[docs]
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Render the full perf report page.

    Args:
        args: Parsed CLI options; when ``None`` they are parsed via ``parse_args``.

    If the report cannot be loaded, the error is shown on the page instead of being raised.
    """
    configure_streamlit_page()
    if args is None:
        args = parse_args()
    try:
        report = load_report(args.report)
    except (OSError, ValueError) as error:
        # OSError covers FileNotFoundError and other I/O failures. ValueError covers
        # json.JSONDecodeError, a non-object report, and UnicodeDecodeError raised when
        # the file is not valid UTF-8.
        st.title("Trinity Perf Report Viewer")
        st.error(str(error))
        return
    render_header(report, args.report)
    render_global_metrics(report)
    render_step_metrics(report)
    render_resource_utilization(report)
if __name__ == "__main__":
    # Inside a Streamlit script run, render the page directly; otherwise relaunch this
    # file through Streamlit's CLI via launch_report_viewer.
    parsed_args = parse_args()
    if not has_streamlit_context():
        launch_report_viewer(parsed_args.report, parsed_args.port)
    else:
        main(parsed_args)