import argparse
import asyncio
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
import streamlit as st
from transformers import AutoTokenizer
from trinity.buffer.storage.sql import SQLExperienceStorage
from trinity.common.config import StorageConfig
from trinity.common.experience import Experience
from trinity.common.experience_visualizer import build_experience_token_view
class _SyncViewerStorage:
    """Thin sync wrapper around async SQLExperienceStorage for Streamlit.

    The wrapper owns a private event loop and drives each coroutine of the
    wrapped storage to completion on it, so callers use plain blocking calls.

    NOTE(review): ``run_until_complete`` is not re-entrant; this assumes a
    single caller at a time -- confirm against how the instance is shared.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Create the async storage and run its ``prepare()`` step.

        Args:
            config (StorageConfig): Configuration forwarded to
                ``SQLExperienceStorage``.
        """
        # Bind both attributes before anything that can raise, so close()
        # and __del__ never see a half-initialised instance.
        self._async = None
        self._loop = asyncio.new_event_loop()
        try:
            self._async = SQLExperienceStorage(config)
            self._loop.run_until_complete(self._async.prepare())
        except BaseException:
            # Release the loop (and any engine already created) on failure,
            # without letting a cleanup error hide the original exception.
            try:
                self.close()
            except Exception:
                pass
            raise

    def _run(self, coro):
        """Run ``coro`` to completion on the private loop and return its result."""
        return self._loop.run_until_complete(coro)

    def close(self):
        """Dispose the storage engine and close the private event loop.

        Safe to call repeatedly and on a partially initialised instance.
        """
        loop = getattr(self, "_loop", None)
        if loop is None or loop.is_closed():
            return
        try:
            engine = getattr(getattr(self, "_async", None), "engine", None)
            if engine is not None:
                loop.run_until_complete(engine.dispose())
        finally:
            # Close the loop even when disposing the engine failed.
            loop.close()

    def __del__(self):
        self.close()

    def query(self, offset: int = 0, limit: int = 10, filters=None) -> List[Experience]:
        """Blocking version of the storage's ``query(offset, limit, filters)``."""
        return self._run(self._async.query(offset, limit, filters))

    def count(self, filters=None) -> int:
        """Blocking version of the storage's ``count(filters)``."""
        return self._run(self._async.count(filters))
class SQLExperienceViewer:
    """Query interface over stored experiences plus a Streamlit launcher."""

    def __init__(self, config: StorageConfig) -> None:
        """Open the storage described by ``config`` through the sync wrapper."""
        self.storage = _SyncViewerStorage(config)

    def get_experiences(
        self, offset: int, limit: int = 10, filters: Optional[Dict] = None
    ) -> List[Experience]:
        """Return the experiences for one page.

        Args:
            offset (int): Offset passed to the storage query.
            limit (int): Limit passed to the storage query.
            filters (Optional[Dict]): Filters forwarded unchanged.
        """
        page = self.storage.query(offset=offset, limit=limit, filters=filters)
        return page

    def total_experiences(self, filters: Optional[Dict] = None) -> int:
        """Return the storage's count for ``filters``."""
        total = self.storage.count(filters=filters)
        return total

    @staticmethod
    def run_viewer(
        model_path: str, db_url: str, table_name: str, schema_type: str, port: int
    ) -> None:
        """Start the Streamlit viewer.

        Replaces ``sys.argv`` with a ``streamlit run`` command line for this
        file and exits the process with the value returned by the CLI.

        Args:
            model_path (str): Path to the tokenizer/model directory.
            db_url (str): Database URL for the experience database.
            table_name (str): Name of the experience table in the database.
            schema_type (str): Schema type of the experience table.
            port (int): Port number to run the Streamlit app on.
        """
        from streamlit.web import cli

        script_path = str(Path(__file__).resolve())
        # Options consumed by Streamlit itself.
        server_options = [
            "--server.port",
            str(port),
            "--server.fileWatcherType",
            "none",
        ]
        # Options after "--" are handed to this script's own parser.
        script_options = [
            "--db-url",
            db_url,
            "--table",
            table_name,
            "--schema",
            schema_type,
            "--tokenizer",
            model_path,
        ]
        sys.argv = ["streamlit", "run", script_path, *server_options, "--", *script_options]
        sys.exit(cli.main())
# Page-level Streamlit settings: page title and wide layout.
st.set_page_config(
    page_title="Trinity-RFT Experience Visualizer",
    layout="wide",
)
def get_color_for_action_mask(action_mask_value: int) -> str:
    """Return the background colour for a token's action-mask value.

    Args:
        action_mask_value (int): Mask value of the token.

    Returns:
        str: ``"#c8e6c9"`` when the value equals 1, otherwise ``"#ffcdd2"``.
    """
    return "#c8e6c9" if action_mask_value == 1 else "#ffcdd2"
def render_token_detail_html(html: str) -> None:
    """Display an HTML fragment on the page through ``st.html``.

    Args:
        html (str): Markup to render.
    """
    st.html(html)
# Stylesheet and opening markup of the per-token detail view.
_TOKEN_DETAIL_HEAD = """
<style>
.token-detail-root * {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.token-detail-root {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
padding: 10px;
}
.token-detail-root .token-container {
display: flex;
flex-wrap: wrap;
gap: 4px;
padding: 12px;
background-color: #fafafa;
border-radius: 6px;
}
.token-detail-root .token-box {
display: inline-flex;
flex-direction: column;
align-items: center;
padding: 6px 10px;
border-radius: 4px;
border: 1px solid #e0e0e0;
min-width: 50px;
transition: transform 0.2s, box-shadow 0.2s;
}
.token-detail-root .token-box:hover {
transform: scale(1.5);
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
z-index: 10;
}
.token-detail-root .token-text {
font-family: 'Courier New', monospace;
font-size: 13px;
font-weight: 600;
margin-bottom: 3px;
text-align: center;
word-break: break-all;
max-width: 90px;
}
.token-detail-root .token-logprob {
font-size: 10px;
color: #666;
font-family: 'Courier New', monospace;
text-align: center;
}
</style>
<div class="token-detail-root">
<div class="token-container">
"""

# Markup for a single token box; filled in with ``str.format``.
_TOKEN_BOX_TEMPLATE = """
<div class="token-box" style="background-color: {bg_color};">
<div class="token-text">{token_display}</div>
<div class="token-logprob">{logprob_text}</div>
</div>
"""

# Closes the two <div> elements opened by ``_TOKEN_DETAIL_HEAD``.
_TOKEN_DETAIL_TAIL = """
</div>
</div>
"""

# Replaces whitespace characters with visible glyphs in token text.
_VISIBLE_WHITESPACE = str.maketrans({" ": "␣", "\n": "↵", "\t": "⇥"})


def _build_token_detail_html(response_tokens: Any) -> str:
    """Build the markup shown in the "Response Tokens Detail" expander.

    Each token becomes one coloured box containing its text and log-prob.
    """
    # Local, aliased import: keeps the fix self-contained and avoids binding
    # the name ``html``, which this module already uses for markup strings.
    from html import escape as escape_html

    parts = [_TOKEN_DETAIL_HEAD]
    for token in response_tokens:
        # Make whitespace visible first, then escape, so characters such as
        # "<" or "&" inside a token cannot be interpreted as markup.
        visible_text = token.token_text.translate(_VISIBLE_WHITESPACE)
        logprob_text = f"{token.logprob:.4f}" if token.logprob is not None else "N/A"
        parts.append(
            _TOKEN_BOX_TEMPLATE.format(
                bg_color=get_color_for_action_mask(int(token.is_action)),
                token_display=escape_html(visible_text, quote=True),
                logprob_text=logprob_text,
            )
        )
    parts.append(_TOKEN_DETAIL_TAIL)
    return "".join(parts)


def render_experience(exp: Experience, tokenizer: Any) -> None:
    """Render a single experience in Streamlit.

    Emits, in order: a separator and header with the experience id, three
    columns (reward, metrics, info), and three collapsed expanders (prompt
    text, response text, per-token detail).

    Args:
        exp (Experience): The experience to display.
        tokenizer (Any): Tokenizer passed to ``build_experience_token_view``.
    """
    token_view = build_experience_token_view(exp, tokenizer)

    st.markdown("---")
    # Header with EID
    st.subheader(f"Experience [{exp.eid}]")

    # Reward and metadata first (before prompt/response)
    col_reward, col_metrics, col_info = st.columns(3)
    with col_reward:
        # A missing reward is displayed as 0.0.
        reward_val = exp.reward if exp.reward is not None else 0.0
        st.markdown("**Reward**")
        st.markdown(f"`{reward_val:.4f}`")
    with col_metrics:
        st.markdown("**Metrics**")
        st.json(exp.metrics or {}, expanded=True)
    with col_info:
        st.markdown("**Info**")
        st.json(exp.info or {}, expanded=False)

    # Prompt (collapsed by default)
    with st.expander("Prompt", expanded=False):
        st.code(token_view.prompt_text, language=None, wrap_lines=True, line_numbers=True)

    # Response (collapsed by default)
    with st.expander("Response", expanded=False):
        st.code(token_view.response_text, language=None, wrap_lines=True, line_numbers=True)

    # Response Tokens Detail (collapsed by default)
    with st.expander("Response Tokens Detail", expanded=False):
        render_token_detail_html(_build_token_detail_html(token_view.response_tokens))
def parse_args():
    """Parse the viewer's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: Holds ``db_url``, ``table``, ``schema``
        (defaults to ``"experience"``) and the required ``tokenizer``.
    """
    parser = argparse.ArgumentParser(description="Experience Visualizer")
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ("--db-url", {"type": str, "help": "Path to the experience database."}),
        ("--table", {"type": str, "help": "Name of the experience table."}),
        (
            "--schema",
            {
                "type": str,
                "default": "experience",
                "choices": ("experience", "sft"),
                "help": "Schema type of the experience table.",
            },
        ),
        ("--tokenizer", {"type": str, "required": True, "help": "Path to the tokenizer."}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
@st.cache_resource
def get_viewer(db_url: str, table_name: str, schema_type: str) -> SQLExperienceViewer:
    """Build the viewer for one table; the result is cached by Streamlit.

    Args:
        db_url (str): Stored as the config's ``path``.
        table_name (str): Stored as the config's ``name``.
        schema_type (str): Stored as the config's ``schema_type``.

    Returns:
        SQLExperienceViewer: Viewer backed by a SQL storage config with
        ``wrap_in_ray`` disabled.
    """
    config = StorageConfig()
    # Applied in this order on top of the default config.
    overrides = {
        "name": table_name,
        "path": db_url,
        "schema_type": schema_type,
        "storage_type": "sql",
        "wrap_in_ray": False,
    }
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)
    return SQLExperienceViewer(config)
@st.cache_resource
def _load_tokenizer(tokenizer_path: str) -> Any:
    """Load the tokenizer once per path instead of on every script rerun."""
    return AutoTokenizer.from_pretrained(tokenizer_path)


def _parse_optional_number(raw: str, cast: Any, label: str) -> Any:
    """Convert a sidebar text field to a number, or ``None``.

    Blank input means "no bound". Text that ``cast`` rejects is reported in
    the sidebar and treated as blank rather than raising ``ValueError``.
    """
    text = raw.strip()
    if not text:
        return None
    try:
        return cast(text)
    except ValueError:
        st.sidebar.error(f"{label}: {raw!r} is not a valid number; ignoring it.")
        return None


def main():  # noqa: [C901]
    """Streamlit entry point: filters in the sidebar, one page of experiences, pagination.

    Filters only take effect when "Apply Filters" is pressed; the committed
    set lives in ``st.session_state.active_filters`` and the current page in
    ``st.session_state.page``.
    """
    args = parse_args()
    viewer = get_viewer(args.db_url, args.table, args.schema)
    st.title("Trinity-RFT Experience Visualizer")

    # Initialize session state
    if "page" not in st.session_state:
        st.session_state.page = 1

    # === Sidebar: Filters ===
    st.sidebar.header("Filters")

    # Reward range filter
    st.sidebar.markdown("**Reward Range**")
    col_rmin, col_rmax = st.sidebar.columns(2)
    with col_rmin:
        reward_min_str = st.text_input("Min", value="", key="reward_min")
    with col_rmax:
        reward_max_str = st.text_input("Max", value="", key="reward_max")
    reward_min = _parse_optional_number(reward_min_str, float, "Reward min")
    reward_max = _parse_optional_number(reward_max_str, float, "Reward max")

    # Model version range filter
    st.sidebar.markdown("**Model Version Range**")
    col_vmin, col_vmax = st.sidebar.columns(2)
    with col_vmin:
        mv_min_str = st.text_input("Min", value="", key="mv_min")
    with col_vmax:
        mv_max_str = st.text_input("Max", value="", key="mv_max")
    model_version_min = _parse_optional_number(mv_min_str, int, "Model version min")
    model_version_max = _parse_optional_number(mv_max_str, int, "Model version max")

    # Task ID exact match filter
    task_id_filter = st.sidebar.text_input("Task ID (exact match)", value="", key="task_id")

    # Apply filters button
    if st.sidebar.button("Apply Filters", use_container_width=True):
        new_filters: Dict = {}
        if reward_min is not None:
            new_filters["reward_min"] = reward_min
        if reward_max is not None:
            new_filters["reward_max"] = reward_max
        if model_version_min is not None:
            new_filters["model_version_min"] = model_version_min
        if model_version_max is not None:
            new_filters["model_version_max"] = model_version_max
        if task_id_filter:
            new_filters["task_id"] = task_id_filter
        st.session_state.active_filters = new_filters
        # A new filter set invalidates the old page position.
        st.session_state.page = 1
        st.rerun()

    # Use committed filters from session state
    if "active_filters" not in st.session_state:
        st.session_state.active_filters = {}
    filters: Dict = st.session_state.active_filters

    # Sidebar bottom: per-page setting (low-profile)
    st.sidebar.markdown("---")
    experiences_per_page = st.sidebar.number_input(
        "Per page", min_value=1, max_value=50, value=10, step=1, key="per_page"
    )

    # Query total with filters (an empty filter dict is passed as None)
    total_seq_num = viewer.total_experiences(filters=filters or None)
    # Ceiling division, with at least one page so the page input stays valid.
    total_pages = max(1, (total_seq_num + experiences_per_page - 1) // experiences_per_page)

    # Clamp current page
    if st.session_state.page > total_pages:
        st.session_state.page = total_pages

    # Calculate offset and fetch
    offset = (st.session_state.page - 1) * experiences_per_page
    experiences = viewer.get_experiences(offset, experiences_per_page, filters=filters or None)
    tokenizer = _load_tokenizer(args.tokenizer)

    # Sidebar: table of contents
    if experiences:
        st.sidebar.markdown("---")
        st.sidebar.markdown("**Contents**")
        for exp in experiences:
            eid_str = str(exp.eid)
            reward_str = f"{exp.reward:.2f}" if exp.reward is not None else "N/A"
            st.sidebar.markdown(
                f"- [{eid_str}](#experience-{eid_str}) (r={reward_str})",
                unsafe_allow_html=True,
            )

    # Render experiences
    if experiences:
        for exp in experiences:
            # Anchor targeted by the sidebar "Contents" links.
            st.markdown(f'<a name="experience-{exp.eid}"></a>', unsafe_allow_html=True)
            render_experience(exp, tokenizer)
    else:
        st.info("No experiences found matching the current filters.")

    # === Bottom: Pagination ===
    st.markdown("---")

    # Row 1: Previous | [current_page] / total_pages | Next
    col_prev, col_page_input, col_slash, col_total, col_next = st.columns([1, 1, 0.3, 0.7, 1])
    with col_prev:
        if st.button("Previous", disabled=(st.session_state.page <= 1)):
            st.session_state.page -= 1
            st.rerun()
    with col_page_input:
        # NOTE(review): this widget is keyed; whether a changed ``value``
        # resets it may depend on the Streamlit version -- confirm that
        # Previous/Next keep it in sync on the pinned version.
        new_page = st.number_input(
            "page",
            min_value=1,
            max_value=total_pages,
            value=st.session_state.page,
            step=1,
            label_visibility="collapsed",
            key="page_input",
        )
        if new_page != st.session_state.page:
            st.session_state.page = new_page
            st.rerun()
    with col_slash:
        st.markdown(
            "<div style='text-align:center;line-height:38px;'>/</div>",
            unsafe_allow_html=True,
        )
    with col_total:
        st.markdown(
            f"<div style='line-height:38px;'>{total_pages}</div>",
            unsafe_allow_html=True,
        )
    with col_next:
        if st.button("Next", disabled=(st.session_state.page >= total_pages)):
            st.session_state.page += 1
            st.rerun()

    # Row 2: total count
    st.caption(f"{total_seq_num} experiences in total")


if __name__ == "__main__":
    main()