# Source code of trinity.explorer.workflow_runner

# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from trinity.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR
from trinity.common.experience import Experience
from trinity.common.models import get_debug_explorer_model
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.utils.log import get_logger


@dataclass(frozen=True)
class Status:
    """Immutable outcome of running a task.

    Attributes:
        ok: Whether the run succeeded.
        metrics: Metric dictionaries, one per run of the task.
        message: Optional human-readable message (e.g. an error text);
            ``None`` when not supplied.
    """

    ok: bool
    # Each dictionary holds the metrics produced by a single run.
    metrics: List[Dict[str, float]]
    message: Optional[str] = None
def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]:
    """Average every metric key over the experiences that report it.

    Intended for non-repeatable workflows: the experiences produced by one
    run are averaged, which yields that run's run-level metrics. Do not use
    this for repeatable workflows.

    Experiences whose ``metrics`` is empty or falsy are skipped. A key is
    averaged only over the experiences that contain it. The result keeps
    the order in which keys were first seen.
    """
    collected: Dict[str, List[float]] = defaultdict(list)
    for experience in experiences:
        if not experience.metrics:
            continue
        for name, value in experience.metrics.items():
            collected[name].append(value)
    # Every list is non-empty: a key only exists once a value was appended.
    return {name: sum(samples) / len(samples) for name, samples in collected.items()}
class WorkflowRunner:
    """A Ray remote actor to run the workflow and generate experiences.

    The runner wraps the rollout model (and each auxiliary model) in a
    ``ModelWrapper``, instantiates workflows from incoming tasks, runs them
    ``repeat_times`` times using the strategy selected by
    ``config.explorer.concurrent_mode``, and returns the generated experiences
    together with per-run metrics.
    """

    def __init__(
        self,
        config: Config,
        model: InferenceModel,
        auxiliary_models: Optional[List[InferenceModel]] = None,
        runner_id: Optional[int] = None,
    ) -> None:
        """Build the model wrappers and select the concurrent run strategy.

        Args:
            config: Global configuration. ``config.explorer`` supplies the runner
                name, rollout/auxiliary model settings and ``concurrent_mode``.
            model: The rollout inference model.
            auxiliary_models: Optional extra inference models. They are paired
                with ``config.explorer.auxiliary_models`` via ``zip``, so any
                surplus on either side gets no wrapper.
            runner_id: Identifier embedded in the runner name.
        """
        self.name = f"{config.explorer.name}_runner_{runner_id}"
        self.logger = get_logger(self.name, in_ray_actor=True)
        self.config = config
        self.model = model
        self.model_wrapper = ModelWrapper(
            model,
            enable_lora=config.explorer.rollout_model.enable_lora,
            enable_history=config.explorer.rollout_model.enable_history,
        )
        self.auxiliary_models = auxiliary_models or []
        # zip() keeps the original pairing: only models with a matching config
        # entry get a wrapper. The config entry itself is not used here.
        self.auxiliary_model_wrappers = [
            ModelWrapper(aux_model)
            for aux_model, _ in zip(self.auxiliary_models, config.explorer.auxiliary_models)
        ]
        self.workflow_instance: Optional[Workflow] = None
        self.runner_id = runner_id
        # Snapshot returned (merged with the wrapper's state) by get_runner_state().
        self.runner_state = {
            "workflow_id": None,
            "model_version": None,
            "begin_time": 0,
            "terminate_time": 0,
        }
        self.concurrent_mode = config.explorer.concurrent_mode
        run_strategies = {
            "sequential": self._sequential_run,
            "asynchronous": self._asynchronous_run,
            "multi-threading": self._multi_threading_run,
        }
        self.concurrent_run_fn = run_strategies.get(self.concurrent_mode)
        if self.concurrent_run_fn is None:
            self.logger.warning(
                f"Unknown concurrent_mode {self.concurrent_mode}, defaulting to sequential."
            )
            self.concurrent_run_fn = self._sequential_run
        self.logger.info(
            f"WorkflowRunner [{self.name}]({self.concurrent_mode}) initialized:\n"
            f"  > rollout model: {self.config.explorer.rollout_model.model_path}\n"
            f"  > auxiliary models: {[aux_model_config.model_path for aux_model_config in self.config.explorer.auxiliary_models]}"
        )

    async def prepare(self) -> None:
        """Prepare the rollout and auxiliary model wrappers concurrently."""
        await asyncio.gather(
            self.model_wrapper.prepare(),
            *(aux_model.prepare() for aux_model in self.auxiliary_model_wrappers),
        )
        self.logger.info(f"WorkflowRunner [{self.name}] is prepared and ready to run tasks.")

    def is_alive(self):
        """Liveness probe; always returns ``True``."""
        return True

    def _create_workflow_instance(self, task: Task) -> Workflow:
        """Return a workflow for ``task``, reusing the cached instance when possible.

        The cached instance is reset and reused only if its class is exactly
        ``task.workflow`` and it is ``resettable``. Otherwise a fresh instance
        is built with ``task.to_workflow``.

        Raises:
            ValueError: If the task has no workflow class.
        """
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        reusable = (
            self.workflow_instance is not None
            and self.workflow_instance.__class__ is task.workflow
            and self.workflow_instance.resettable
        )
        if reusable:
            self.workflow_instance.reset(task)
        else:
            # Pass ModelWrapper directly; Workflow.__init__ will get OpenAI clients automatically
            self.workflow_instance = task.to_workflow(
                self.model_wrapper,
                self.auxiliary_model_wrappers,
            )
        return self.workflow_instance

    async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]:
        """Run a workflow, awaiting ``run_async`` if it is asynchronous, else calling ``run``."""
        if workflow_instance.asynchronous:
            return await workflow_instance.run_async()
        return workflow_instance.run()

    def _mark_run_started(self, workflow_id: str, begin_time: float) -> None:
        """Record in ``runner_state`` that the run ``workflow_id`` has begun."""
        self.runner_state["workflow_id"] = workflow_id
        self.runner_state["terminate_time"] = None
        self.runner_state["begin_time"] = begin_time

    def _summarize_run(
        self, new_exps: List[Experience], begin_time: float, end_time: float, run_id: int
    ) -> Dict[str, float]:
        """Compute one run's metrics (plus its wall time) and stamp its experiences with ``run_id``."""
        run_metric = calculate_run_level_metrics(new_exps)
        run_metric["time/run_execution"] = end_time - begin_time
        for exp in new_exps:
            exp.eid.run = run_id
        return run_metric

    def _model_wrapper_for_run(self) -> ModelWrapper:
        """Pick the rollout wrapper for one concurrent run.

        When history is enabled, each run gets ``clone_with_isolated_history()``.
        Presumably this keeps concurrent runs from sharing recorded history.
        """
        if self.config.explorer.rollout_model.enable_history:
            return self.model_wrapper.clone_with_isolated_history()
        return self.model_wrapper

    @staticmethod
    def _merge_run_results(
        results: List[Tuple[List[Experience], Dict]],
    ) -> Tuple[List[Experience], List[Dict]]:
        """Flatten per-run ``(experiences, metric)`` pairs, preserving run order."""
        exps: List[Experience] = []
        run_metrics: List[Dict] = []
        for new_exps, run_metric in results:
            exps.extend(new_exps)
            run_metrics.append(run_metric)
        return exps, run_metrics

    async def _run_task(
        self, task: Task, repeat_times: int, run_id_base: int
    ) -> Tuple[List[Experience], List[Dict]]:
        """Init workflow from the task and run it ``repeat_times`` times.

        Repeatable workflows (``task.workflow.can_repeat``) receive the repeat
        count and do all repeats inside a single run. Other workflows are
        repeated by the strategy selected from ``concurrent_mode``.

        Raises:
            ValueError: If the task has no workflow class.
        """
        # Check before reading ``can_repeat`` so a missing workflow yields a
        # clear error rather than an AttributeError on ``None``.
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        if not task.workflow.can_repeat:
            return await self.concurrent_run_fn(task, repeat_times, run_id_base)

        workflow_instance = self._create_workflow_instance(task)
        workflow_instance.set_repeat_times(repeat_times, run_id_base)
        st = time.time()
        await self.model_wrapper.clean_workflow_state()
        self._mark_run_started(f"{task.batch_id}/{task.task_id}/{run_id_base}", st)
        exps = await self._run_workflow(workflow_instance)
        et = time.time()
        self.runner_state["terminate_time"] = et
        # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly
        # NOTE: these are the experiences' own dicts, so the timing entry is
        # also written into each ``exp.metrics``.
        run_metrics = [exp.metrics for exp in exps if exp.metrics]
        for metric in run_metrics:
            metric["time/run_execution"] = et - st
        return exps, run_metrics

    async def _sequential_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
    ) -> Tuple[List[Experience], List[Dict]]:
        """Run the task ``repeat_times`` times one after another.

        Uses the cached or reset workflow instance for each run.
        """
        exps: List[Experience] = []
        run_metrics: List[Dict] = []
        for i in range(repeat_times):
            # workflow_id uses the same absolute run id as ``exp.eid.run``.
            run_id = run_id_base + i
            st = time.time()
            workflow = self._create_workflow_instance(task)
            await self.model_wrapper.clean_workflow_state()
            self._mark_run_started(f"{task.batch_id}/{task.task_id}/{run_id}", st)
            new_exps = await self._run_workflow(workflow)
            et = time.time()
            self.runner_state["terminate_time"] = et
            run_metrics.append(self._summarize_run(new_exps, st, et, run_id))
            exps.extend(new_exps)
        return exps, run_metrics

    async def _asynchronous_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
    ) -> Tuple[List[Experience], List[Dict]]:
        """Run ``repeat_times`` fresh workflow instances concurrently on this event loop.

        All runs write to the shared ``runner_state``, so it reflects whichever
        run updated it last.
        """

        async def run_single(i: int) -> Tuple[List[Experience], Dict]:
            run_id = run_id_base + i
            st = time.time()
            workflow = task.to_workflow(
                self._model_wrapper_for_run(),
                self.auxiliary_model_wrappers,
            )
            await self.model_wrapper.clean_workflow_state()
            self._mark_run_started(f"{task.batch_id}/{task.task_id}/{run_id}", st)
            new_exps = await self._run_workflow(workflow)
            et = time.time()
            self.runner_state["terminate_time"] = et
            return new_exps, self._summarize_run(new_exps, st, et, run_id)

        results = await asyncio.gather(*(run_single(i) for i in range(repeat_times)))
        return self._merge_run_results(results)

    async def _multi_threading_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
    ) -> Tuple[List[Experience], List[Dict]]:
        """Run ``repeat_times`` fresh workflow instances, each in its own thread and event loop.

        NOTE(review): each run executes under ``asyncio.run`` in a worker
        thread. This assumes the wrappers' coroutines are usable from an event
        loop other than the actor's own; verify.
        """

        async def run_single(i: int) -> Tuple[List[Experience], Dict]:
            run_id = run_id_base + i
            st = time.time()
            await self.model_wrapper.clean_workflow_state()
            workflow = task.to_workflow(
                self._model_wrapper_for_run(),
                self.auxiliary_model_wrappers,
            )
            self._mark_run_started(f"{task.batch_id}/{task.task_id}/{run_id}", st)
            new_exps = await self._run_workflow(workflow)
            et = time.time()
            self.runner_state["terminate_time"] = et
            return new_exps, self._summarize_run(new_exps, st, et, run_id)

        # Use asyncio.to_thread to run async tasks in threads; the default
        # argument binds the loop index eagerly (avoids late binding).
        results = await asyncio.gather(
            *(
                asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx)))  # type: ignore[misc]
                for i in range(repeat_times)
            )
        )
        return self._merge_run_results(results)

    async def get_runner_state(self) -> Dict:
        """Return a copy of ``runner_state`` merged with the wrapper's workflow state."""
        runner_state = self.runner_state.copy()
        runner_state.update(await self.model_wrapper.get_workflow_state())
        return runner_state

    async def run_task(
        self,
        task: Task,
        batch_id: str,
        repeat_times: int = 1,
        run_id_base: int = 0,
    ) -> Tuple[Status, List[Experience]]:
        """Run the task and return its status and experiences.

        Args:
            task: The task to run.
            batch_id: Step label used only in the start-of-task log line.
                Experiences are stamped with ``task.batch_id`` instead.
            repeat_times: Number of runs to perform.
            run_id_base: Base run id. Non-repeatable runs are numbered
                ``run_id_base + i``.

        Returns:
            ``(Status, experiences)``. Evaluation tasks (``task.is_eval``) return
            an empty experience list. On any exception the status has
            ``ok=False``, the exception text as ``message`` and a single timing
            metric, and no experiences are returned.
        """
        # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
        st = time.time()
        try:
            model_version = await self.model_wrapper.model_version_async
            self.runner_state["model_version"] = model_version
            self.logger.info(
                f"Starting task: step={batch_id}, model_version={model_version}, repeat_times={repeat_times}, run_id_base={run_id_base}"
            )
            exps, metrics = await self._run_task(task, repeat_times, run_id_base)
            # Explicit check instead of ``assert`` so it still fires under ``python -O``.
            if not exps:
                raise ValueError("An empty experience is generated")
            # set eid for each experience
            for exp in exps:
                exp.eid.batch = task.batch_id
                # keep exp.eid.task if it has been set before (e.g., in workflow); "" is the default value
                if exp.eid.task == "":
                    exp.eid.task = task.task_id
                if getattr(exp, "info", None) is None:
                    exp.info = {}
                exp.info["model_version"] = model_version
                exp.info["use_count"] = 0
                exp.info["task_index"] = task.index
                if getattr(exp, "metrics", None) is None:
                    exp.metrics = {}
            # Evaluation experiences are not recorded to the buffer.
            return Status(True, metrics=metrics), ([] if task.is_eval else exps)
        except Exception as e:
            error_trace_back = traceback.format_exc()
            self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
            return (
                Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)),
                [],
            )
class DebugWorkflowRunner(WorkflowRunner):
    """A WorkflowRunner that runs a single task locally for debugging.

    Logs, an optional VizTracer profiling report and a SQLite database of the
    produced experiences are all written under ``output_dir``.
    """

    def __init__(
        self,
        config: Config,
        output_dir: str = "debug_output",
        enable_profiling: bool = False,
        disable_overwrite: bool = False,
    ) -> None:
        """Create the debug models, output directory and SQLite experience writer.

        Args:
            config: Global configuration. The first explorer taskset is read.
            output_dir: Directory for logs, profiling output and experiences.
            enable_profiling: Wrap the run in a VizTracer profiler.
            disable_overwrite: If ``output_dir`` already exists and is
                non-empty, use ``<output_dir>_<YYYYmmddHHMMSS>`` instead.
        """
        model, auxiliary_models = get_debug_explorer_model(config)
        must_relocate = disable_overwrite and os.path.isdir(output_dir) and os.listdir(output_dir)
        if must_relocate:
            # strftime() with no time tuple formats the current local time.
            output_dir = f"{output_dir}_{time.strftime('%Y%m%d%H%M%S')}"
        # Set before the base initializer creates the logger.
        os.environ[LOG_DIR_ENV_VAR] = os.path.join(output_dir, "log")
        os.environ[LOG_LEVEL_ENV_VAR] = "DEBUG"
        super().__init__(config, model, auxiliary_models, runner_id=0)
        self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
        self.output_dir = output_dir
        self.enable_profiling = enable_profiling
        self.logger.info(f"Debug output directory: {self.output_dir}")
        os.makedirs(self.output_dir, exist_ok=True)
        self.output_profiling_file = os.path.join(self.output_dir, "profiling.html")
        db_path = os.path.join(self.output_dir, "experiences.db")
        self.output_sqlite_file = f"sqlite:///{db_path}"
        debug_storage = StorageConfig(
            name="debug_buffer",
            schema_type="experience",
            path=self.output_sqlite_file,
            storage_type="sql",
            batch_size=1,
            wrap_in_ray=False,
        )
        self.sqlite_writer = get_buffer_writer(debug_storage)

    async def prepare(self) -> None:
        """Start the underlying models via their ``.remote()`` handles, then prepare the wrappers."""
        # make sure models are started
        await asyncio.gather(
            self.model.prepare.remote(),
            *[aux.prepare.remote() for aux in self.auxiliary_models],
        )
        await super().prepare()

    async def debug(self) -> None:
        """Run the first task from the taskset once and store its experiences in SQLite."""
        await self.prepare()
        task = (await self.taskset.read_async(batch_size=1))[0]
        self.logger.info(f"Start debugging task:\n{task.raw_task}")
        run_kwargs = dict(task=task, batch_id="debug", repeat_times=1, run_id_base=0)
        if self.enable_profiling:
            from viztracer import VizTracer

            with VizTracer(output_file=self.output_profiling_file):
                status, exps = await self.run_task(**run_kwargs)
        else:
            status, exps = await self.run_task(**run_kwargs)

        if not status.ok and not exps:
            # A failed run returns no experiences; recover what the model
            # wrapper's history holds instead.
            exps = self.model_wrapper.extract_experience_from_history()
            self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.")
        await self.sqlite_writer.write_async(exps)

        if not status.ok:
            self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
        else:
            print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
        self.logger.info("Debugging completed.")