"""SQLAlchemy models for different data."""
from typing import Dict, Optional, Tuple
from sqlalchemy import JSON, Column, DateTime, Float, Integer, LargeBinary, String, func
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import declarative_base
from trinity.common.experience import Experience
from trinity.utils.log import get_logger
Base = declarative_base()
[docs]
class TaskModel(Base):  # type: ignore
    """Abstract SQLAlchemy base for task tables.

    Each row stores one task as a JSON document in ``raw_task``. The model is
    marked abstract, so a concrete subclass must supply ``__tablename__``.
    """

    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    raw_task = Column(JSON, nullable=False)

    @classmethod
    def from_dict(cls, dict: Dict):
        """Build an unsaved row whose ``raw_task`` column holds ``dict``.

        The parameter name shadows the builtin; it is kept unchanged for
        interface compatibility with existing callers.
        """
        row_fields = {"raw_task": dict}
        return cls(**row_fields)
# ============================================================
# Experience Models (meta + blob split)
# ============================================================
[docs]
class ExperienceModel(Base):  # type: ignore
    """Abstract SQLAlchemy model for Experience metadata.

    Only scalar, queryable fields live in this table; the full serialized
    ``Experience`` is stored separately in a blob table.
    """

    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(DateTime, server_default=func.now())
    # Identity fields mirrored from ``Experience.eid``. All are nullable, so a
    # missing id is stored as SQL NULL (never as the string "None").
    task_id = Column(String(64), nullable=True, index=True)
    run_id = Column(Integer, nullable=True, index=True)
    msg_id = Column(String(64), nullable=True, index=True)
    model_version = Column(Integer, nullable=True, index=True)
    reward = Column(Float, nullable=True, index=True)
    # NOTE(review): not read or written elsewhere in this file; presumably a
    # consumption flag/counter maintained by the buffer code — confirm.
    consumed = Column(Integer, default=0, index=True)

    def to_experience(self, blob_bytes: bytes) -> Experience:
        """Rebuild an Experience from its blob, overlaying this row's metadata.

        The blob is deserialized first; then ``eid.task``, ``eid.run``,
        ``eid.suffix``, ``reward`` and ``info["model_version"]`` are
        overwritten with the values stored in this metadata row.

        Args:
            blob_bytes: Serialized experience bytes from the blob table.

        Returns:
            The deserialized ``Experience`` with metadata applied.
        """
        exp = Experience.deserialize(blob_bytes)
        # NOTE(review): ``task_id`` is a String column, so a non-string task id
        # comes back as ``str`` after a round trip — confirm callers expect this.
        exp.eid.task = self.task_id
        exp.eid.run = self.run_id
        exp.eid.suffix = self.msg_id
        exp.reward = self.reward
        exp.info["model_version"] = self.model_version
        return exp

    @classmethod
    def from_experience(cls, experience: Experience):
        """Create a metadata row from ``experience`` (blob stored separately).

        ``eid.task`` and ``eid.suffix`` are stringified to fit their
        ``String(64)`` columns, but ``None`` is preserved as SQL NULL rather
        than being turned into the literal string ``"None"``.

        Args:
            experience: Source experience; only its metadata fields are read.

        Returns:
            An unsaved instance of ``cls``.
        """
        task = experience.eid.task
        suffix = experience.eid.suffix
        return cls(
            reward=experience.reward,
            task_id=None if task is None else str(task),
            run_id=experience.eid.run,
            msg_id=None if suffix is None else str(suffix),
            model_version=experience.info.get("model_version"),
        )
[docs]
class BlobModel(Base):  # type: ignore
    """Unified blob storage model for all experience types.

    Holds the serialized experience bytes. ``_create_table_classes`` pairs each
    experience/SFT/DPO metadata table with a ``<table_name>_blob`` table built
    on this model.
    """

    __abstract__ = True

    # Unlike the metadata models, no ``autoincrement=True`` is given here.
    # NOTE(review): presumably the writer sets this to the id of the matching
    # metadata row so the two tables join 1:1 — confirm against the storage code.
    id = Column(Integer, primary_key=True)
    # NOTE(review): on MySQL, a length-less ``LargeBinary`` becomes ``BLOB``
    # (64 KiB max); if MySQL is a supported backend, large experiences may
    # need a LONGBLOB variant.
    experience_bytes = Column(LargeBinary, nullable=False)
# ============================================================
# SFT Models (meta + blob split)
# ============================================================
[docs]
class SFTDataModel(Base):  # type: ignore
    """Abstract SQLAlchemy model holding SFT metadata.

    Only the message list is kept as a JSON column; the experience itself is
    rebuilt purely from the paired blob bytes.
    """

    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    message_list = Column(JSON, nullable=True)

    def to_experience(self, blob_bytes: bytes) -> Experience:
        """Deserialize ``blob_bytes`` into an Experience.

        No metadata column of this row is consulted; the blob alone
        determines the result.
        """
        restored = Experience.deserialize(blob_bytes)
        return restored

    @classmethod
    def from_experience(cls, experience: Experience):
        """Build an unsaved metadata row carrying ``experience.messages``."""
        messages = experience.messages
        return cls(message_list=messages)
# ============================================================
# DPO Models (meta + blob split)
# ============================================================
[docs]
class DPODataModel(Base):  # type: ignore
    """Abstract SQLAlchemy model holding DPO metadata.

    The chosen and rejected message lists are kept as JSON columns; the
    experience itself is rebuilt purely from the paired blob bytes.
    """

    __abstract__ = True

    id = Column(Integer, primary_key=True, autoincrement=True)
    chosen_message_list = Column(JSON, nullable=True)
    rejected_message_list = Column(JSON, nullable=True)

    def to_experience(self, blob_bytes: bytes) -> Experience:
        """Deserialize ``blob_bytes`` into an Experience.

        The message-list columns are not read back; the blob alone
        determines the result.
        """
        restored = Experience.deserialize(blob_bytes)
        return restored

    @classmethod
    def from_experience(cls, experience: Experience):
        """Build an unsaved metadata row from the chosen/rejected messages."""
        columns = {
            "chosen_message_list": experience.chosen_messages,
            "rejected_message_list": experience.rejected_messages,
        }
        return cls(**columns)
# ============================================================
# Engine initialization
# ============================================================
def _create_table_classes(table_name: str, schema_type: Optional[str]):
    """Create dynamic table model classes for the given schema type.

    Args:
        table_name: SQL table name. For non-task schemas this names the
            metadata table; the blob table is ``f"{table_name}_blob"``.
        schema_type: Key into ``SQL_SCHEMA``; ``None`` is treated as ``"task"``.

    Returns:
        For task schema: ``(table_cls,)``.
        For experience/sft/dpo schema: ``(meta_cls, blob_cls)``.

    Raises:
        ValueError: If ``schema_type`` is not a key of ``SQL_SCHEMA``.
    """
    from trinity.buffer.schema import SQL_SCHEMA

    def _concrete_attrs(name: str) -> dict:
        # Class attributes turning an abstract model into a concrete table.
        # ``keep_existing`` reuses a Table of the same name already present in
        # ``Base.metadata`` instead of raising on redefinition.
        return {
            "__tablename__": name,
            "__abstract__": False,
            "__table_args__": {"keep_existing": True},
        }

    if schema_type is None:
        schema_type = "task"
    base_class = SQL_SCHEMA.get(schema_type)
    if base_class is None:
        # Without this check ``type()`` fails later with an opaque TypeError
        # about ``None`` not being a valid base class.
        raise ValueError(
            f"Unknown schema_type {schema_type!r}; expected one of {list(SQL_SCHEMA)}."
        )

    if schema_type == "task":
        table_cls = type(table_name, (base_class,), _concrete_attrs(table_name))
        return (table_cls,)

    meta_cls = type(f"{table_name}_meta", (base_class,), _concrete_attrs(table_name))
    blob_table_name = f"{table_name}_blob"
    blob_cls = type(blob_table_name, (BlobModel,), _concrete_attrs(blob_table_name))
    return (meta_cls, blob_cls)
[docs]
async def init_async_engine(db_url: str, table_name: str, schema_type: Optional[str]) -> Tuple:
    """Create an async SQLAlchemy engine and table classes.

    Only the tables belonging to ``table_name`` are created (when missing).
    Other models registered on the shared ``Base.metadata`` are not touched,
    so initializing one table never creates another table's schema in this
    database.

    Args:
        db_url: Database URL; converted with ``to_async_url`` before use.
        table_name: Name of the (metadata) table to bind.
        schema_type: Schema key; ``None`` is treated as ``"task"``.

    Returns:
        For task schema: (async_engine, table_cls)
        For experience/sft/dpo schema: (async_engine, meta_cls, blob_cls)

    Raises:
        Any non-``OperationalError`` raised while creating tables, after the
        engine has been disposed.
    """
    from trinity.buffer.utils import to_async_url

    logger = get_logger(__name__)
    if schema_type is None:
        schema_type = "task"
    # Resolve the model classes first so an invalid schema fails before any
    # engine (and connection pool) is created.
    classes = _create_table_classes(table_name, schema_type)
    tables = [cls.__table__ for cls in classes]

    engine = create_async_engine(to_async_url(db_url), pool_pre_ping=True)
    try:
        async with engine.begin() as conn:
            # Restrict DDL to this call's tables; ``checkfirst`` (the default)
            # skips tables that already exist.
            await conn.run_sync(Base.metadata.create_all, tables=tables)
        logger.info(f"Created async tables for {table_name} (schema={schema_type}).")
    except OperationalError as exc:
        # Best effort, as before: treat the failure as "tables already exist",
        # but surface the reason for debugging.
        logger.warning(
            f"Failed to create async tables for {table_name}, assuming they exist: {exc}"
        )
    except Exception:
        # The engine is never handed to the caller on this path; release its pool.
        await engine.dispose()
        raise
    # ``classes`` is (table_cls,) or (meta_cls, blob_cls).
    return (engine, *classes)