Source code for trinity.buffer.storage.sql

"""SQL database storage for experience and task buffers.

Primary classes (async, used by Ray actors and production hot paths):
    SQLExperienceStorage, SQLTaskStorage

Factory:
    SQLStorage.get_wrapper(config) — returns Ray actor handle.
"""

import asyncio
import os
import time
from typing import Dict, List, Optional

import ray
from datasets import Dataset
from sqlalchemy import and_, asc, desc, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from trinity.buffer.schema import FORMATTER
from trinity.buffer.schema.sql_schema import init_async_engine
from trinity.buffer.utils import async_run_with_retry_session
from trinity.common.config import StorageConfig
from trinity.common.constants import MAX_EXP_BYTES_ENV_VAR
from trinity.common.experience import Experience
from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.workflows import WORKFLOWS, Task
from trinity.utils.log import get_logger

# ---------------------------------------------------------------------------
# Async primary implementations
# ---------------------------------------------------------------------------


class SQLExperienceStorage:
    """Primary async SQL storage for experiences. Used directly as Ray actor.

    Each experience is stored as two rows sharing one primary key: a metadata
    row (``table_model_cls``) and a blob row (``blob_model_cls``) holding the
    serialized payload. Reads either walk the table in id order (FIFO) or
    pick the least-consumed, newest rows (priority), depending on
    ``config.schema_type``.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Store the configuration; no database work happens until ``prepare``.

        Args:
            config: Storage configuration. ``schema_type == "experience"``
                selects the priority read path, anything else selects FIFO.
        """
        self.logger = get_logger(f"sql_{config.name}")
        self.config = config
        self.max_timeout = config.max_read_timeout
        self.batch_size = config.batch_size
        self.enable_replay = config.replay_buffer is not None and config.replay_buffer.enable
        # Upper bound for one serialized experience; a value <= 0 disables the check.
        self.max_experience_bytes = int(os.getenv(MAX_EXP_BYTES_ENV_VAR, 1024 * 1024 * 32))
        self.max_retry_times = config.max_retry_times
        self.max_retry_interval = config.max_retry_interval
        self.ref_count = 0
        self.stopped = False
        # Id of the last row handed out by the FIFO reader.
        self.offset = config.index
        self._initialized = False
        # Created lazily in ``prepare`` so it is built inside the running event loop.
        self._prepare_lock: Optional[asyncio.Lock] = None
        if config.schema_type == "experience":
            self._read_method = self._read_priority
        else:
            self._read_method = self._read_fifo

    async def prepare(self) -> None:
        """Initialize async engine and create tables.

        Safe to call repeatedly and concurrently: the first caller performs
        the initialization, later callers wait for it and then return.
        """
        if self._initialized:
            return
        if self._prepare_lock is None:
            self._prepare_lock = asyncio.Lock()
        async with self._prepare_lock:
            # Another coroutine may have finished initialization while this
            # one was waiting for the lock.
            if self._initialized:
                return
            result = await init_async_engine(
                self.config.path, self.config.name, self.config.schema_type  # type: ignore
            )
            self.engine, self.table_model_cls, self.blob_model_cls = result
            self.session = async_sessionmaker(self.engine, expire_on_commit=False)
            self._initialized = True
            self.logger.info(f"SQL storage initialized at {self.config.path}")

    async def write(self, data: List[Experience]) -> None:
        """Persist experiences, skipping any whose payload is too large.

        Args:
            data: Experiences to store. An experience whose serialized size
                exceeds ``max_experience_bytes`` is logged and dropped.
        """
        await self.prepare()

        # Serialize and filter once, outside the retried operation, so a
        # retry neither repeats the work nor repeats the warnings.
        payloads = []
        for exp in data:
            exp_bytes = exp.serialize()
            if (
                self.max_experience_bytes > 0
                and exp_bytes is not None
                and len(exp_bytes) > self.max_experience_bytes
            ):
                self.logger.warning(
                    f"Experience size {len(exp_bytes)} bytes exceeds "
                    f"max_experience_bytes {self.max_experience_bytes}, skipping."
                )
                continue
            payloads.append((exp, exp_bytes))

        async def operation(session: AsyncSession):
            for exp, exp_bytes in payloads:
                meta_row = self.table_model_cls.from_experience(exp)
                session.add(meta_row)
                # Flush so the metadata row has its id before the blob row uses it.
                await session.flush()
                blob_row = self.blob_model_cls(id=meta_row.id, experience_bytes=exp_bytes)
                session.add(blob_row)

        await async_run_with_retry_session(
            self.session, operation, self.max_retry_times, self.max_retry_interval
        )
        self.logger.info(f"Write {len(payloads)} experiences to SQL storage.")

    async def _fetch_blobs(self, session: AsyncSession, ids: List[int]) -> Dict[int, bytes]:
        """Return a mapping ``id -> serialized bytes`` for the given row ids."""
        if not ids:
            return {}
        stmt = select(self.blob_model_cls).where(self.blob_model_cls.id.in_(ids))
        result = await session.execute(stmt)
        blobs = result.scalars().all()
        return {b.id: b.experience_bytes for b in blobs}

    def _assemble_experiences(self, meta_rows, blob_map: Dict[int, bytes]) -> List[Experience]:
        """Combine metadata rows with their blobs, preserving row order.

        Rows without a matching blob are logged and left out of the result.
        """
        experiences = []
        for row in meta_rows:
            blob_bytes = blob_map.get(row.id)
            if blob_bytes is None:
                self.logger.warning(f"Missing blob for experience id={row.id}, skipping.")
                continue
            experiences.append(row.to_experience(blob_bytes))
        return experiences

    async def _read_fifo(self, batch_size: int) -> List[Experience]:
        """Read ``batch_size`` experiences in ascending id order.

        Polls once per second until enough rows are available. The timeout
        clock restarts whenever new rows arrive.

        Raises:
            StopAsyncIteration: If the storage was stopped or no new rows
                arrived within ``max_timeout`` seconds.
        """
        exp_list: List[Experience] = []
        start_time = time.time()
        while len(exp_list) < batch_size:
            if self.stopped:
                raise StopAsyncIteration()
            if time.time() - start_time > self.max_timeout:
                # NOTE(review): rows already collected in exp_list are dropped
                # here although self.offset has moved past them.
                self.logger.warning(
                    f"Max read timeout reached ({self.max_timeout} s), "
                    f"only got {len(exp_list)} experiences, stopping..."
                )
                raise StopAsyncIteration()

            current_offset = self.offset
            remaining = batch_size - len(exp_list)

            async def operation(session: AsyncSession):
                stmt = (
                    select(self.table_model_cls)
                    .where(self.table_model_cls.id > current_offset)
                    .order_by(asc(self.table_model_cls.id))
                    .limit(remaining)
                )
                result = await session.execute(stmt)
                meta_rows = result.scalars().all()
                if not meta_rows:
                    return [], None
                ids = [row.id for row in meta_rows]
                blob_map = await self._fetch_blobs(session, ids)
                return (
                    self._assemble_experiences(meta_rows, blob_map),
                    meta_rows[-1].id,
                )

            experiences, next_offset = await async_run_with_retry_session(
                self.session, operation, self.max_retry_times, self.max_retry_interval
            )
            if next_offset is not None:
                self.offset = next_offset
                start_time = time.time()
            exp_list.extend(experiences)
            if len(exp_list) < batch_size:
                self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
                await asyncio.sleep(1)
        return exp_list

    async def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Experience]:
        """Read one full batch, preferring least-consumed and newest rows.

        Rows are ordered by ``consumed`` ascending, then ``id`` descending.
        Nothing is returned or marked until a full batch is available; once
        it is, every selected row gets its ``consumed`` counter incremented.
        With replay disabled only rows with ``consumed == 0`` are eligible.

        Args:
            batch_size: Number of rows that must be available.
            min_model_version: When positive, only rows with at least this
                model version are considered.

        Raises:
            StopAsyncIteration: If the storage was stopped or the number of
                available rows did not change within ``max_timeout`` seconds.
        """
        exp_list: List[Experience] = []
        start_time = time.time()
        latest_size = 0
        # None of these change while polling, so resolve them once.
        enable_replay = self.enable_replay
        table_cls = self.table_model_cls
        is_sqlite = self.engine.dialect.name == "sqlite"

        async def operation(session: AsyncSession):
            stmt = select(table_cls)
            if min_model_version > 0:
                stmt = stmt.where(table_cls.model_version >= min_model_version)
            if not enable_replay:
                stmt = stmt.where(table_cls.consumed == 0)
            stmt = stmt.order_by(asc(table_cls.consumed), desc(table_cls.id)).limit(batch_size)
            # Row locks are only requested on non-SQLite dialects.
            if not is_sqlite:
                stmt = stmt.with_for_update()
            result = await session.execute(stmt)
            meta_rows = result.scalars().all()
            if len(meta_rows) != batch_size:
                # Not enough rows yet: report the count, touch nothing.
                return len(meta_rows), False, []
            ids = [row.id for row in meta_rows]
            update_stmt = (
                update(table_cls)
                .where(table_cls.id.in_(ids))
                .values(consumed=table_cls.consumed + 1)
            )
            await session.execute(update_stmt)
            blob_map = await self._fetch_blobs(session, ids)
            return (
                len(meta_rows),
                True,
                self._assemble_experiences(meta_rows, blob_map),
            )

        while latest_size < batch_size:
            if self.stopped:
                raise StopAsyncIteration()
            if time.time() - start_time > self.max_timeout:
                self.logger.warning(
                    f"Max read timeout reached ({self.max_timeout} s), "
                    f"only got {latest_size} experiences, stopping..."
                )
                raise StopAsyncIteration()

            latest_batch_size, has_full_batch, experiences = await async_run_with_retry_session(
                self.session, operation, self.max_retry_times, self.max_retry_interval
            )
            if has_full_batch:
                exp_list.extend(experiences)
                break
            if latest_size != latest_batch_size:
                # Progress was made, so restart the timeout clock.
                latest_size = latest_batch_size
                start_time = time.time()
            # exp_list stays empty until a full batch is read, so the number
            # still missing is derived from the last observed row count.
            self.logger.info(f"Waiting for {batch_size - latest_size} more experiences...")
            await asyncio.sleep(1)
        return exp_list

    async def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
        """Read a batch using the read strategy chosen at construction.

        Args:
            batch_size: Number of experiences to read; defaults to the
                configured batch size.
            **kwargs: Forwarded unchanged to the selected read method.

        Raises:
            StopAsyncIteration: If the storage has been stopped.
        """
        await self.prepare()
        if self.stopped:
            raise StopAsyncIteration()
        batch_size = self.batch_size if batch_size is None else batch_size
        return await self._read_method(batch_size, **kwargs)

    def _build_filter_conditions(self, filters: Optional[Dict] = None):
        """Translate a filter dict into a list of SQLAlchemy conditions.

        Recognized keys: ``reward_min``, ``reward_max``, ``model_version_min``,
        ``model_version_max`` (applied when not ``None``) and ``task_id``
        (applied when truthy, so an empty or zero value means "no filter").
        """
        conditions = []
        if not filters:
            return conditions
        if filters.get("reward_min") is not None:
            conditions.append(self.table_model_cls.reward >= filters["reward_min"])
        if filters.get("reward_max") is not None:
            conditions.append(self.table_model_cls.reward <= filters["reward_max"])
        if filters.get("model_version_min") is not None:
            conditions.append(self.table_model_cls.model_version >= filters["model_version_min"])
        if filters.get("model_version_max") is not None:
            conditions.append(self.table_model_cls.model_version <= filters["model_version_max"])
        if filters.get("task_id"):
            conditions.append(self.table_model_cls.task_id == filters["task_id"])
        return conditions

    async def count(self, filters: Optional[Dict] = None) -> int:
        """Return the number of stored experiences matching ``filters``."""
        await self.prepare()

        async def operation(session: AsyncSession):
            stmt = select(func.count()).select_from(self.table_model_cls)
            conditions = self._build_filter_conditions(filters)
            if conditions:
                stmt = stmt.where(and_(*conditions))
            result = await session.execute(stmt)
            return result.scalar()

        return await async_run_with_retry_session(
            self.session, operation, self.max_retry_times, self.max_retry_interval
        )

    async def query(
        self, offset: int = 0, limit: int = 10, filters: Optional[Dict] = None
    ) -> List[Experience]:
        """Return one page of experiences matching ``filters``.

        Rows are ordered by ascending id so that consecutive pages are
        stable and do not overlap. Reading does not change any row.

        Args:
            offset: Number of matching rows to skip.
            limit: Maximum number of rows to return.
            filters: Optional filter dict, see ``_build_filter_conditions``.
        """
        await self.prepare()

        async def operation(session: AsyncSession):
            stmt = select(self.table_model_cls)
            conditions = self._build_filter_conditions(filters)
            if conditions:
                stmt = stmt.where(and_(*conditions))
            # OFFSET/LIMIT without ORDER BY leaves the page contents unspecified.
            stmt = stmt.order_by(asc(self.table_model_cls.id)).offset(offset).limit(limit)
            result = await session.execute(stmt)
            meta_rows = result.scalars().all()
            if not meta_rows:
                return []
            ids = [row.id for row in meta_rows]
            blob_map = await self._fetch_blobs(session, ids)
            return self._assemble_experiences(meta_rows, blob_map)

        return await async_run_with_retry_session(
            self.session, operation, self.max_retry_times, self.max_retry_interval
        )

    @classmethod
    async def load_from_dataset(
        cls, dataset: Dataset, config: StorageConfig
    ) -> "SQLExperienceStorage":
        """Create a storage and fill it with the formatted items of ``dataset``.

        Items are converted with the formatter registered for
        ``config.schema_type`` and written in chunks of the configured
        batch size.
        """
        storage = cls(config)
        await storage.prepare()
        formatter = FORMATTER.get(config.schema_type)(
            tokenizer_path=config.tokenizer_path, format_config=config.format
        )
        batch_size = storage.batch_size
        batch = []
        for item in dataset:
            batch.append(formatter.format(item))
            if len(batch) >= batch_size:
                await storage.write(batch)
                batch = []
        if batch:
            await storage.write(batch)
        return storage

    def acquire(self) -> int:
        """Register one more user of this storage and return the new count."""
        self.ref_count += 1
        return self.ref_count

    def release(self) -> int:
        """Unregister one user; the storage is marked stopped when none remain."""
        self.ref_count -= 1
        if self.ref_count <= 0:
            self.stopped = True
        return self.ref_count
class SQLTaskStorage:
    """Primary async SQL storage for tasks. Used directly as Ray actor.

    Tasks are read in ascending id order, starting after ``config.index``.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Store the configuration; no database work happens until ``prepare``.

        Args:
            config: Storage configuration. When ``total_steps`` is set, the
                sample budget is ``batch_size * total_steps``; otherwise it
                is unbounded.
        """
        self.logger = get_logger(f"sql_{config.name}")
        self.config = config
        self.batch_size = config.batch_size
        self.is_eval = config.is_eval
        self.max_retry_times = config.max_retry_times
        self.max_retry_interval = config.max_retry_interval
        self.ref_count = 0
        self.stopped = False
        # Id of the last task row handed out by ``read``.
        self.offset = config.index
        self._initialized = False
        # Created lazily in ``prepare`` so it is built inside the running event loop.
        self._prepare_lock: Optional[asyncio.Lock] = None
        if config.total_steps:
            self.total_samples = self.batch_size * config.total_steps
        else:
            self.total_samples = float("inf")

    async def prepare(self) -> None:
        """Initialize async engine and create tables.

        Safe to call repeatedly and concurrently: the first caller performs
        the initialization, later callers wait for it and then return.
        """
        if self._initialized:
            return
        if self._prepare_lock is None:
            self._prepare_lock = asyncio.Lock()
        async with self._prepare_lock:
            # Another coroutine may have finished initialization while this
            # one was waiting for the lock.
            if self._initialized:
                return
            from trinity.buffer.schema.formatter import TaskFormatter

            result = await init_async_engine(
                self.config.path, self.config.name, self.config.schema_type  # type: ignore
            )
            self.engine, self.table_model_cls = result
            self.session = async_sessionmaker(self.engine, expire_on_commit=False)
            self.default_workflow_cls = WORKFLOWS.get(self.config.default_workflow_type)
            self.default_reward_fn_cls = REWARD_FUNCTIONS.get(self.config.default_reward_fn_type)
            self.formatter = TaskFormatter(self.config)
            self._initialized = True
            self.logger.info(f"SQL task storage initialized at {self.config.path}")

    async def write(self, data: List[Dict]) -> None:
        """Insert one task row per dict in ``data``."""
        await self.prepare()

        async def operation(session: AsyncSession):
            tasks = [self.table_model_cls.from_dict(item) for item in data]
            session.add_all(tasks)

        await async_run_with_retry_session(
            self.session, operation, self.max_retry_times, self.max_retry_interval
        )

    async def read(self, batch_size: Optional[int] = None) -> List[Task]:
        """Read the next batch of tasks in id order and advance the offset.

        Args:
            batch_size: Number of tasks to read; defaults to the configured
                batch size.

        Returns:
            The formatted tasks of the batch.

        Raises:
            StopAsyncIteration: If the storage was stopped, the sample budget
                is exhausted, no rows remain, or (outside eval mode) fewer
                than ``batch_size`` rows remain.
        """
        await self.prepare()
        if self.stopped:
            raise StopAsyncIteration()
        # NOTE(review): with ``>`` a read is still served when offset equals
        # total_samples; confirm whether ``>=`` is intended.
        if self.offset > self.total_samples:
            raise StopAsyncIteration()
        batch_size = self.batch_size if batch_size is None else batch_size
        table_cls = self.table_model_cls

        async def operation(session: AsyncSession):
            stmt = (
                select(table_cls)
                .where(table_cls.id > self.offset)
                .order_by(asc(table_cls.id))
                .limit(batch_size)
            )
            result = await session.execute(stmt)
            results = result.scalars().all()
            # Signal exhaustion with a sentinel instead of raising inside the
            # retried operation, so end-of-data is never treated as a
            # transient failure by the retry helper.
            if len(results) == 0:
                return None
            if not self.is_eval and len(results) < batch_size:
                return None
            return results[-1].id, [self.formatter.format(item.raw_task) for item in results]

        outcome = await async_run_with_retry_session(
            self.session, operation, self.max_retry_times, self.max_retry_interval
        )
        if outcome is None:
            raise StopAsyncIteration()
        self.offset, tasks = outcome
        return tasks

    @classmethod
    async def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLTaskStorage":
        """Create a storage and fill it with the raw items of ``dataset``.

        Items are written in chunks of ``config.batch_size``.
        """
        storage = cls(config)
        await storage.prepare()
        batch_size = config.batch_size
        batch = []
        for item in dataset:
            batch.append(item)
            if len(batch) >= batch_size:
                await storage.write(batch)
                batch = []
        if batch:
            await storage.write(batch)
        return storage

    def acquire(self) -> int:
        """Register one more user of this storage and return the new count."""
        self.ref_count += 1
        return self.ref_count

    def release(self) -> int:
        """Unregister one user; the storage is marked stopped when none remain."""
        self.ref_count -= 1
        if self.ref_count <= 0:
            self.stopped = True
        return self.ref_count
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
class SQLStorage:
    """Factory for creating SQL storage Ray actors."""

    @classmethod
    def get_wrapper(cls, config: StorageConfig):
        """Return a handle to the named Ray actor backing ``config``.

        A missing ``schema_type`` selects task storage; any other value
        selects experience storage. The actor is named ``sql-<config.name>``
        and is looked up before being created, so repeated calls with the
        same name and namespace share one actor.
        """
        storage_cls = SQLTaskStorage if config.schema_type is None else SQLExperienceStorage
        remote_cls = ray.remote(storage_cls)
        # Fall back to the caller's current namespace when none is configured.
        namespace = config.ray_namespace or ray.get_runtime_context().namespace
        actor_options = {
            "name": f"sql-{config.name}",
            "namespace": namespace,
            "get_if_exists": True,
            "max_concurrency": 5,
        }
        return remote_cls.options(**actor_options).remote(config)